import torch
import torch.nn as nn


class NormalizePower(nn.Module):
    """
    Layer for normalizing a batch of vectors so that their average power
    (mean squared norm) is 1.

    :attr epsilon: Minimum mean power to avoid division by zero.
    """

    epsilon = 1e-12

    def forward(self, x):
        # Mean squared norm over the batch.
        average_power = (torch.norm(x, dim=1) ** 2).mean()
        # Clamp to avoid division by zero. Unlike rebuilding a tensor with
        # torch.max(torch.tensor([...])), clamping keeps the gradient and
        # stays on the same device as `x`.
        average_power = torch.clamp(average_power, min=NormalizePower.epsilon)
        return x * torch.rsqrt(average_power)
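

# A minimal usage sketch, not part of the original module: it feeds a random
# batch through the layer and checks that the mean squared norm of the output
# is approximately 1. The batch shape (32, 4) is an arbitrary example.
if __name__ == "__main__":
    layer = NormalizePower()
    x = torch.randn(32, 4)  # batch of 32 vectors of dimension 4
    y = layer(x)
    mean_power = (torch.norm(y, dim=1) ** 2).mean()
    print(f"mean output power: {mean_power.item():.6f}")  # expected ~1.0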