constellationnet/constellation/NormalizePower.py

21 lines
509 B
Python

import torch.nn as nn
import torch
class NormalizePower(nn.Module):
    """
    Layer for normalizing a batch of vectors so that their average power
    (mean squared Euclidean length) is 1.

    The L2 norm of ``x`` is taken along dimension 1, squared, and averaged
    over all remaining entries. The whole input is then multiplied by the
    reciprocal square root of that average. For a 2-D input this presumably
    means ``(batch_size, vector_dim)`` -- TODO confirm against callers.

    :attr epsilon: Lower bound on the average power, used to avoid a
        division by zero (e.g. when the input is all zeros).
    """

    epsilon = 1e-12

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale ``x`` so that its average power becomes 1.

        :param x: Floating-point tensor whose dimension 1 holds the vector
            components.
        :return: ``x`` scaled by ``1 / sqrt(max(average_power, epsilon))``;
            same shape as ``x``.
        """
        average_power = (torch.norm(x, dim=1) ** 2).mean()
        # Clamp in-graph instead of building a new tensor with torch.tensor():
        # that call copies the value into a fresh CPU leaf tensor, which cuts
        # the normalization factor out of autograd and forces a device-to-host
        # sync when x lives on a GPU.
        average_power = torch.clamp(average_power, min=self.epsilon)
        return x * torch.rsqrt(average_power)