2019-12-13 17:11:09 +00:00
|
|
|
|
import torch.nn as nn
|
2019-12-16 00:42:50 +00:00
|
|
|
|
import torch
|
2019-12-15 04:04:35 +00:00
|
|
|
|
from .GaussianChannel import GaussianChannel
|
2019-12-15 14:42:33 +00:00
|
|
|
|
from .NormalizePower import NormalizePower
|
2019-12-16 00:42:50 +00:00
|
|
|
|
from . import util
|
2019-12-13 17:11:09 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConstellationNet(nn.Module):
    """
    Autoencoder network to automatically shape a constellation of symbols for
    efficient communication over an optical fiber channel.
    """

    def __init__(
        self,
        order=2,
        encoder_layers_sizes=(),
        decoder_layers_sizes=(),
        channel_model=None,
    ):
        """
        Create an autoencoder.

        :param order: Order of the constellation, i.e. the number of messages
        that are to be transmitted, or equivalently the number of symbols whose
        placements in the constellation have to be learned.
        :param encoder_layers_sizes: Shape of the encoder’s hidden layers. The
        size of this sequence is the number of hidden layers, with each element
        being a number which specifies the number of neurons in its channel.
        :param decoder_layers_sizes: Shape of the decoder’s hidden layers. Uses
        the same convention as `encoder_layers_sizes` above.
        :param channel_model: Instance of the channel model to use between the
        encoder and decoder network. Defaults to a new `GaussianChannel`.
        """
        super().__init__()
        self.order = order

        # Use a None sentinel instead of `channel_model=GaussianChannel()` in
        # the signature: a default argument is evaluated once at definition
        # time, so a module default would be silently shared (together with
        # any state it holds) by every instance built with the default.
        if channel_model is None:
            channel_model = GaussianChannel()

        # Build the encoder network taking a one-hot encoded message as input
        # and outputting an I/Q vector. The network additionally uses hidden
        # layers as specified in `encoder_layers_sizes`
        prev_layer_size = order
        encoder_layers = []

        # NOTE(review): no activation function is inserted between hidden
        # layers, so stacked Linear layers compose to a single affine map —
        # confirm this is intentional.
        for layer_size in encoder_layers_sizes:
            encoder_layers.append(nn.Linear(prev_layer_size, layer_size))
            prev_layer_size = layer_size

        encoder_layers += [
            nn.Linear(prev_layer_size, 2),
            NormalizePower(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.channel = channel_model

        # Build the decoder network taking the noisy I/Q vector received from
        # the channel as input and outputting a probability vector for each
        # original message. The network additionally uses hidden layers as
        # specified in `decoder_layers_sizes`
        prev_layer_size = 2
        decoder_layers = []

        # NOTE(review): same as the encoder — no nonlinearity between the
        # decoder's hidden layers.
        for layer_size in decoder_layers_sizes:
            decoder_layers.append(nn.Linear(prev_layer_size, layer_size))
            prev_layer_size = layer_size

        # Softmax is not used at the end of the network because the
        # CrossEntropyLoss criterion is used for training, which includes
        # LogSoftmax
        decoder_layers.append(nn.Linear(prev_layer_size, order))

        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """
        Perform encoding and decoding of an input vector and compute its
        reconstructed vector.

        :param x: Original one-hot encoded data.
        :return: Reconstructed vector.
        """
        symbol = self.encoder(x)
        noisy_symbol = self.channel(symbol)
        return self.decoder(noisy_symbol)

    def get_constellation(self):
        """
        Extract symbol constellation out of the trained encoder.

        :return: Matrix containing `order` rows with the nᵗʰ one being the I/Q
        vector that is the result of encoding the nᵗʰ message.
        """
        # Inference only: no gradient bookkeeping needed for extraction
        with torch.no_grad():
            return self.encoder(
                util.messages_to_onehot(
                    torch.arange(0, self.order),
                    self.order
                )
            )
|