21 lines
509 B
Python
21 lines
509 B
Python
import torch.nn as nn
|
|
import torch
|
|
|
|
|
|
class NormalizePower(nn.Module):
    """
    Layer that rescales a batch of vectors so that their average power is 1.

    The *power* of a vector is its squared Euclidean norm, taken along
    ``dim=1``. These powers are averaged over all remaining elements (for a
    2-D ``(batch, features)`` input: over the batch), and the whole input is
    multiplied by the inverse square root of that average. A single scale
    factor is shared by the entire batch, so vectors keep their relative
    lengths.

    :attr epsilon: Lower bound on the average power, to avoid division by
        zero (e.g. for an all-zero input). May be overridden on a subclass
        or an instance.
    """

    epsilon = 1e-12

    def forward(self, x):
        """
        Scale ``x`` so that the mean of its squared norms along dim 1 is 1.

        :param x: Input tensor with at least 2 dimensions.
        :return: ``x`` multiplied by ``rsqrt(max(mean power, epsilon))``;
            same shape as ``x``.
        """
        average_power = (torch.norm(x, dim=1) ** 2).mean()
        # Clamp with an in-graph op (rather than rebuilding a tensor via
        # torch.tensor([...])) so the value stays on x's device/dtype and
        # gradients flow through the normalisation factor.
        average_power = torch.clamp(average_power, min=self.epsilon)
        return x * torch.rsqrt(average_power)
|