Simplify arguments for ConstellationNet

This commit is contained in:
Mattéo Delabre 2019-12-18 10:26:35 -05:00
parent a7e9dd2230
commit d8a140d793
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
3 changed files with 27 additions and 26 deletions

View File

@ -13,9 +13,8 @@ class ConstellationNet(nn.Module):
def __init__(
self,
order=2,
encoder_layers_sizes=(),
decoder_layers_sizes=(),
channel_model=GaussianChannel()
encoder_layers=(),
decoder_layers=(),
):
"""
Create an autoencoder.
@ -23,56 +22,54 @@ class ConstellationNet(nn.Module):
:param order: Order of the constellation, i.e. the number of messages
that are to be transmitted, or equivalently the number of symbols whose
placements in the constellation have to be learned.
:param encoder_layers_sizes: Shape of the encoders hidden layers. The
:param encoder_layers: Shape of the encoders hidden layers. The
size of this sequence is the number of hidden layers, with each element
being a number which specifies the number of neurons in its channel.
:param decoder_layers_sizes: Shape of the decoders hidden layers. Uses
:param decoder_layers: Shape of the decoders hidden layers. Uses
the same convention as `encoder_layers_sizes` above.
:param channel_model: Instance of the channel model to use between the
encoder and decoder network.
"""
super().__init__()
self.order = order
# Build the encoder network taking a one-hot encoded message as input
# and outputting an I/Q vector. The network additionally uses hidden
# layers as specified in `encoder_layers_sizes`
# layers as specified in `encoder_layers`
prev_layer_size = order
encoder_layers = []
encoder_layers_list = []
for layer_size in encoder_layers_sizes:
encoder_layers.append(nn.Linear(prev_layer_size, layer_size))
encoder_layers.append(nn.ReLU())
encoder_layers.append(nn.BatchNorm1d(layer_size))
for layer_size in encoder_layers:
encoder_layers_list.append(nn.Linear(prev_layer_size, layer_size))
encoder_layers_list.append(nn.ReLU())
encoder_layers_list.append(nn.BatchNorm1d(layer_size))
prev_layer_size = layer_size
encoder_layers += [
encoder_layers_list += [
nn.Linear(prev_layer_size, 2),
NormalizePower(),
]
self.encoder = nn.Sequential(*encoder_layers)
self.channel = channel_model
self.encoder = nn.Sequential(*encoder_layers_list)
self.channel = GaussianChannel()
# Build the decoder network taking the noisy I/Q vector received from
# the channel as input and outputting a probability vector for each
# original message. The network additionally uses hidden layers as
# specified in `decoder_layers_sizes`
# specified in `decoder_layers`
prev_layer_size = 2
decoder_layers = []
decoder_layers_list = []
for layer_size in decoder_layers_sizes:
decoder_layers.append(nn.Linear(prev_layer_size, layer_size))
encoder_layers.append(nn.ReLU())
decoder_layers.append(nn.BatchNorm1d(layer_size))
for layer_size in decoder_layers:
decoder_layers_list.append(nn.Linear(prev_layer_size, layer_size))
encoder_layers_list.append(nn.ReLU())
decoder_layers_list.append(nn.BatchNorm1d(layer_size))
prev_layer_size = layer_size
# Softmax is not used at the end of the network because the
# CrossEntropyLoss criterion is used for training, which includes
# LogSoftmax
decoder_layers.append(nn.Linear(prev_layer_size, order))
decoder_layers_list.append(nn.Linear(prev_layer_size, order))
self.decoder = nn.Sequential(*decoder_layers)
self.decoder = nn.Sequential(*decoder_layers_list)
def forward(self, x):
"""
@ -88,7 +85,7 @@ class ConstellationNet(nn.Module):
def get_constellation(self):
"""
Extract symbol constellation out of the trained encoder.
Extract the symbol constellation out of the trained encoder.
:return: Matrix containing `order` rows with the nᵗʰ one being the I/Q
vector that is the result of encoding the nᵗʰ message.

View File

@ -4,6 +4,10 @@ import math
class GaussianChannel(nn.Module):
"""
Simulated communication channel that assumes a Gaussian noise model for
taking in account interference.
"""
def __init__(self):
super().__init__()

View File

@ -5,7 +5,7 @@ from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero
import warnings
torch.manual_seed(42)
torch.manual_seed(57)
# Number of symbols to learn
order = 16