Mattéo Delabre
4 years ago
3 changed files with 31 additions and 5 deletions
@ -0,0 +1,20 @@ |
|||
import torch.nn as nn |
|||
import torch |
|||
|
|||
|
|||
class NormalizePower(nn.Module): |
|||
""" |
|||
Layer for normalizing a batch of vectors so that their average length is 1. |
|||
|
|||
:attr epsilon: Minimum mean length to avoid division by zero. |
|||
""" |
|||
epsilon = 1e-12 |
|||
|
|||
def forward(self, x): |
|||
average_power = (torch.norm(x, dim=1) ** 2).mean() |
|||
average_power = torch.max(torch.tensor([ |
|||
NormalizePower.epsilon, |
|||
average_power |
|||
])) |
|||
|
|||
return x * torch.rsqrt(average_power) |
Loading…
Reference in new issue