21 lines
509 B
Python
21 lines
509 B
Python
|
import torch.nn as nn
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class NormalizePower(nn.Module):
    """
    Layer for normalizing a batch of vectors so that their average power
    (mean squared Euclidean norm) is 1.

    Squared magnitudes are summed along dim 1 (e.g. the feature axis of a
    ``(batch, features)`` tensor -- confirm for other layouts) and averaged
    over all remaining elements. The whole input is then multiplied by a
    single scalar, so relative magnitudes between vectors are preserved.

    :attr epsilon: Lower bound on the average power (mean squared norm) to
        avoid division by zero for an all-zero input.
    """

    epsilon = 1e-12

    def forward(self, x):
        """
        Scale ``x`` so that the mean of its squared norms along dim 1 is 1.

        :param x: Tensor with at least two dimensions; norms are taken along
            dim 1.
        :return: ``x / sqrt(max(mean_power, epsilon))``, same shape, device
            and dtype as ``x``. Gradients flow through the scaling factor.
        """
        # Squared norm per vector computed directly, rather than as
        # torch.norm(...) ** 2: this skips a sqrt/square round trip and the
        # ill-defined gradient of the norm at all-zero vectors.
        power = x.abs().pow(2).sum(dim=1)
        # torch.clamp keeps the value in the autograd graph and on x's
        # device/dtype. Building torch.tensor([eps, power]) instead would copy
        # (and detach) the value onto the CPU in the default dtype.
        average_power = torch.clamp(power.mean(), min=self.epsilon)
        return x * torch.rsqrt(average_power)
|