Browse Source

Add normalization layer

master
Mattéo Delabre 4 years ago
parent
commit
3365603614
Signed by: matteo GPG Key ID: AE3FBD02DC583ABB
  1. 14
      constellation/ConstellationNet.py
  2. 2
      constellation/GaussianChannel.py
  3. 20
      constellation/NormalizePower.py

14
constellation/ConstellationNet.py

@ -1,8 +1,13 @@
import torch.nn as nn
from .GaussianChannel import GaussianChannel
from .NormalizePower import NormalizePower
class ConstellationNet(nn.Module):
"""
Autoencoder network to automatically shape a constellation of symbols for
efficient communication over an optical fiber channel.
"""
def __init__(
self,
order=2,
@ -11,9 +16,7 @@ class ConstellationNet(nn.Module):
channel_model=GaussianChannel()
):
"""
Create an encoder-decoder network to automatically shape a
constellation of symbols for efficient communication over an optical
fiber channel.
Create an autoencoder.
:param order: Order of the constellation, i.e. the number of messages
that are to be transmitted, or equivalently the number of symbols whose
@ -39,7 +42,10 @@ class ConstellationNet(nn.Module):
encoder_layers.append(nn.SELU())
prev_layer_size = layer_size
encoder_layers.append(nn.Linear(prev_layer_size, 2))
encoder_layers += [
nn.Linear(prev_layer_size, 2),
NormalizePower(),
]
self.encoder = nn.Sequential(*encoder_layers)
self.channel = channel_model

2
constellation/GaussianChannel.py

@ -37,7 +37,7 @@ def channel_OSNR():
def Const_Points_Pow(IQ):
"""
Calculate the average power of a set of vectors.
Calculate the average power of a constellation.
"""
p_enc_avg = (torch.norm(IQ, dim=1) ** 2).mean()
p_enc_avg_dB = 10 * torch.log10(p_enc_avg)

20
constellation/NormalizePower.py

@ -0,0 +1,20 @@
import torch.nn as nn
import torch
class NormalizePower(nn.Module):
"""
Layer for normalizing a batch of vectors so that their average length is 1.
:attr epsilon: Minimum mean length to avoid division by zero.
"""
epsilon = 1e-12
def forward(self, x):
average_power = (torch.norm(x, dim=1) ** 2).mean()
average_power = torch.max(torch.tensor([
NormalizePower.epsilon,
average_power
]))
return x * torch.rsqrt(average_power)
Loading…
Cancel
Save