Add normalization layer
This commit is contained in:
parent
b97ba61f42
commit
3365603614
|
@ -1,8 +1,13 @@
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .GaussianChannel import GaussianChannel
|
from .GaussianChannel import GaussianChannel
|
||||||
|
from .NormalizePower import NormalizePower
|
||||||
|
|
||||||
|
|
||||||
class ConstellationNet(nn.Module):
|
class ConstellationNet(nn.Module):
|
||||||
|
"""
|
||||||
|
Autoencoder network to automatically shape a constellation of symbols for
|
||||||
|
efficient communication over an optical fiber channel.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
order=2,
|
order=2,
|
||||||
|
@ -11,9 +16,7 @@ class ConstellationNet(nn.Module):
|
||||||
channel_model=GaussianChannel()
|
channel_model=GaussianChannel()
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create an encoder-decoder network to automatically shape a
|
Create an autoencoder.
|
||||||
constellation of symbols for efficient communication over an optical
|
|
||||||
fiber channel.
|
|
||||||
|
|
||||||
:param order: Order of the constellation, i.e. the number of messages
|
:param order: Order of the constellation, i.e. the number of messages
|
||||||
that are to be transmitted, or equivalently the number of symbols whose
|
that are to be transmitted, or equivalently the number of symbols whose
|
||||||
|
@ -39,7 +42,10 @@ class ConstellationNet(nn.Module):
|
||||||
encoder_layers.append(nn.SELU())
|
encoder_layers.append(nn.SELU())
|
||||||
prev_layer_size = layer_size
|
prev_layer_size = layer_size
|
||||||
|
|
||||||
encoder_layers.append(nn.Linear(prev_layer_size, 2))
|
encoder_layers += [
|
||||||
|
nn.Linear(prev_layer_size, 2),
|
||||||
|
NormalizePower(),
|
||||||
|
]
|
||||||
|
|
||||||
self.encoder = nn.Sequential(*encoder_layers)
|
self.encoder = nn.Sequential(*encoder_layers)
|
||||||
self.channel = channel_model
|
self.channel = channel_model
|
||||||
|
|
|
@ -37,7 +37,7 @@ def channel_OSNR():
|
||||||
|
|
||||||
def Const_Points_Pow(IQ):
|
def Const_Points_Pow(IQ):
|
||||||
"""
|
"""
|
||||||
Calculate the average power of a set of vectors.
|
Calculate the average power of a constellation.
|
||||||
"""
|
"""
|
||||||
p_enc_avg = (torch.norm(IQ, dim=1) ** 2).mean()
|
p_enc_avg = (torch.norm(IQ, dim=1) ** 2).mean()
|
||||||
p_enc_avg_dB = 10 * torch.log10(p_enc_avg)
|
p_enc_avg_dB = 10 * torch.log10(p_enc_avg)
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizePower(nn.Module):
|
||||||
|
"""
|
||||||
|
Layer for normalizing a batch of vectors so that their average length is 1.
|
||||||
|
|
||||||
|
:attr epsilon: Minimum mean length to avoid division by zero.
|
||||||
|
"""
|
||||||
|
epsilon = 1e-12
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
average_power = (torch.norm(x, dim=1) ** 2).mean()
|
||||||
|
average_power = torch.max(torch.tensor([
|
||||||
|
NormalizePower.epsilon,
|
||||||
|
average_power
|
||||||
|
]))
|
||||||
|
|
||||||
|
return x * torch.rsqrt(average_power)
|
Loading…
Reference in New Issue