# constellationnet/constellation/ConstellationNet.py
import torch
import torch.nn as nn

from . import util
from .GaussianChannel import GaussianChannel
from .NormalizePower import NormalizePower
class ConstellationNet(nn.Module):
    """
    Autoencoder network to automatically shape a constellation of symbols for
    efficient communication over an optical fiber channel.

    The encoder maps a one-hot encoded message to a power-normalized I/Q
    vector, the channel adds Gaussian noise, and the decoder recovers logits
    over the original messages.
    """

    def __init__(
        self,
        order=2,
        encoder_layers=(),
        decoder_layers=(),
    ):
        """
        Create an autoencoder.

        :param order: Order of the constellation, i.e. the number of messages
        that are to be transmitted, or equivalently the number of symbols whose
        placements in the constellation have to be learned.
        :param encoder_layers: Shape of the encoder's hidden layers. The
        size of this sequence is the number of hidden layers, with each element
        being a number which specifies the number of neurons in its channel.
        :param decoder_layers: Shape of the decoder's hidden layers. Uses
        the same convention as `encoder_layers` above.
        """
        super().__init__()
        self.order = order

        # Build the encoder network taking a one-hot encoded message as input
        # and outputting an I/Q vector. The network additionally uses hidden
        # layers as specified in `encoder_layers`
        prev_layer_size = order
        encoder_layers_list = []

        for layer_size in encoder_layers:
            encoder_layers_list.append(nn.Linear(prev_layer_size, layer_size))
            encoder_layers_list.append(nn.ReLU())
            encoder_layers_list.append(nn.BatchNorm1d(layer_size))
            prev_layer_size = layer_size

        # Final projection to a 2-dimensional I/Q point, normalized so the
        # constellation has unit average power
        encoder_layers_list += [
            nn.Linear(prev_layer_size, 2),
            NormalizePower(),
        ]

        self.encoder = nn.Sequential(*encoder_layers_list)
        self.channel = GaussianChannel()

        # Build the decoder network taking the noisy I/Q vector received from
        # the channel as input and outputting a probability vector for each
        # original message. The network additionally uses hidden layers as
        # specified in `decoder_layers`
        prev_layer_size = 2
        decoder_layers_list = []

        for layer_size in decoder_layers:
            decoder_layers_list.append(nn.Linear(prev_layer_size, layer_size))
            # Fix: the ReLU was previously appended to `encoder_layers_list`
            # (already consumed above), leaving decoder hidden layers without
            # a nonlinearity
            decoder_layers_list.append(nn.ReLU())
            decoder_layers_list.append(nn.BatchNorm1d(layer_size))
            prev_layer_size = layer_size

        # Softmax is not used at the end of the network because the
        # CrossEntropyLoss criterion is used for training, which includes
        # LogSoftmax
        decoder_layers_list.append(nn.Linear(prev_layer_size, order))
        self.decoder = nn.Sequential(*decoder_layers_list)

    def forward(self, x):
        """
        Perform encoding and decoding of an input vector and compute its
        reconstructed vector.

        :param x: Original one-hot encoded data.
        :return: Reconstructed vector (unnormalized logits over messages).
        """
        symbol = self.encoder(x)
        noisy_symbol = self.channel(symbol)
        return self.decoder(noisy_symbol)

    def get_constellation(self):
        """
        Extract the symbol constellation out of the trained encoder.

        :return: Matrix containing `order` rows with the nᵗʰ one being the I/Q
        vector that is the result of encoding the nᵗʰ message.
        """
        # No gradients needed: this is a read-only probe of the encoder
        with torch.no_grad():
            return self.encoder(
                util.messages_to_onehot(
                    torch.arange(0, self.order),
                    self.order
                )
            )