import torch.nn as nn
from .GaussianChannel import GaussianChannel


class ConstellationNet(nn.Module):
    """
    Encoder-channel-decoder network used to learn a constellation of symbols.

    A one-hot encoded message is mapped by the encoder to a two-component
    I/Q vector, passed through the channel model, then mapped by the decoder
    to one score per possible message.
    """

    def __init__(
        self,
        order=2,
        encoder_layers_sizes=(),
        decoder_layers_sizes=(),
        channel_model=None,
    ):
        """
        Create an encoder-decoder network to automatically shape a
        constellation of symbols for efficient communication over an optical
        fiber channel.

        :param order: Order of the constellation, i.e. the number of messages
            that are to be transmitted, or equivalently the number of symbols
            whose placements in the constellation have to be learned.
        :param encoder_layers_sizes: Shape of the encoder's hidden layers.
            The size of this sequence is the number of hidden layers, with
            each element being a number which specifies the number of
            neurons in the corresponding layer.
        :param decoder_layers_sizes: Shape of the decoder's hidden layers.
            Uses the same convention as `encoder_layers_sizes` above.
        :param channel_model: Instance of the channel model to use between
            the encoder and decoder network. When omitted (or `None`), a new
            `GaussianChannel()` is created for this network.
        """
        super().__init__()

        # Build the encoder network taking a one-hot encoded message as input
        # and outputting an I/Q vector. The network additionally uses hidden
        # layers as specified in `encoder_layers_sizes`
        encoder_layers, encoder_width = self._hidden_stack(
            order, encoder_layers_sizes
        )
        encoder_layers += [
            nn.Linear(encoder_width, 2),
            nn.Tanh(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # The default channel is created here rather than in the signature:
        # a default written as `GaussianChannel()` would be evaluated only
        # once, at class definition time, and the very same module instance
        # would then be shared by every network relying on the default
        if channel_model is None:
            channel_model = GaussianChannel()

        self.channel = channel_model

        # Build the decoder network taking the noisy I/Q vector received from
        # the channel as input and outputting one score for each original
        # message. The network additionally uses hidden layers as specified
        # in `decoder_layers_sizes`
        decoder_layers, decoder_width = self._hidden_stack(
            2, decoder_layers_sizes
        )

        # Softmax is not used at the end of the network because the
        # CrossEntropyLoss criterion is used for training, which includes
        # LogSoftmax
        decoder_layers.append(nn.Linear(decoder_width, order))
        self.decoder = nn.Sequential(*decoder_layers)

    @staticmethod
    def _hidden_stack(input_size, layers_sizes):
        """
        Build the hidden layers of a network, each followed by a Tanh.

        :param input_size: Number of features received by the first layer.
        :param layers_sizes: Sequence holding the number of neurons of each
            hidden layer, in order.
        :return: Tuple `(layers, output_size)` where `layers` is the list of
            created modules and `output_size` is the number of features
            produced by the last of them (`input_size` if `layers_sizes` is
            empty).
        """
        layers = []
        prev_layer_size = input_size

        for layer_size in layers_sizes:
            layers.append(nn.Linear(prev_layer_size, layer_size))
            layers.append(nn.Tanh())
            prev_layer_size = layer_size

        return layers, prev_layer_size

    def forward(self, x):
        """
        Perform encoding and decoding of an input vector and compute its
        reconstructed vector.

        :param x: Original one-hot encoded data.
        :return: Reconstructed vector, as produced by the last linear layer
            of the decoder (no softmax is applied).
        """
        symbol = self.encoder(x)
        noisy_symbol = self.channel(symbol)
        return self.decoder(noisy_symbol)