Simplify arguments for ConstellationNet
This commit is contained in:
parent
a7e9dd2230
commit
d8a140d793
|
@ -13,9 +13,8 @@ class ConstellationNet(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
order=2,
|
||||
encoder_layers_sizes=(),
|
||||
decoder_layers_sizes=(),
|
||||
channel_model=GaussianChannel()
|
||||
encoder_layers=(),
|
||||
decoder_layers=(),
|
||||
):
|
||||
"""
|
||||
Create an autoencoder.
|
||||
|
@ -23,56 +22,54 @@ class ConstellationNet(nn.Module):
|
|||
:param order: Order of the constellation, i.e. the number of messages
|
||||
that are to be transmitted, or equivalently the number of symbols whose
|
||||
placements in the constellation have to be learned.
|
||||
:param encoder_layers_sizes: Shape of the encoder’s hidden layers. The
|
||||
:param encoder_layers: Shape of the encoder’s hidden layers. The
|
||||
size of this sequence is the number of hidden layers, with each element
|
||||
being a number which specifies the number of neurons in its channel.
|
||||
:param decoder_layers_sizes: Shape of the decoder’s hidden layers. Uses
|
||||
:param decoder_layers: Shape of the decoder’s hidden layers. Uses
|
||||
the same convention as `encoder_layers_sizes` above.
|
||||
:param channel_model: Instance of the channel model to use between the
|
||||
encoder and decoder network.
|
||||
"""
|
||||
super().__init__()
|
||||
self.order = order
|
||||
|
||||
# Build the encoder network taking a one-hot encoded message as input
|
||||
# and outputting an I/Q vector. The network additionally uses hidden
|
||||
# layers as specified in `encoder_layers_sizes`
|
||||
# layers as specified in `encoder_layers`
|
||||
prev_layer_size = order
|
||||
encoder_layers = []
|
||||
encoder_layers_list = []
|
||||
|
||||
for layer_size in encoder_layers_sizes:
|
||||
encoder_layers.append(nn.Linear(prev_layer_size, layer_size))
|
||||
encoder_layers.append(nn.ReLU())
|
||||
encoder_layers.append(nn.BatchNorm1d(layer_size))
|
||||
for layer_size in encoder_layers:
|
||||
encoder_layers_list.append(nn.Linear(prev_layer_size, layer_size))
|
||||
encoder_layers_list.append(nn.ReLU())
|
||||
encoder_layers_list.append(nn.BatchNorm1d(layer_size))
|
||||
prev_layer_size = layer_size
|
||||
|
||||
encoder_layers += [
|
||||
encoder_layers_list += [
|
||||
nn.Linear(prev_layer_size, 2),
|
||||
NormalizePower(),
|
||||
]
|
||||
|
||||
self.encoder = nn.Sequential(*encoder_layers)
|
||||
self.channel = channel_model
|
||||
self.encoder = nn.Sequential(*encoder_layers_list)
|
||||
self.channel = GaussianChannel()
|
||||
|
||||
# Build the decoder network taking the noisy I/Q vector received from
|
||||
# the channel as input and outputting a probability vector for each
|
||||
# original message. The network additionally uses hidden layers as
|
||||
# specified in `decoder_layers_sizes`
|
||||
# specified in `decoder_layers`
|
||||
prev_layer_size = 2
|
||||
decoder_layers = []
|
||||
decoder_layers_list = []
|
||||
|
||||
for layer_size in decoder_layers_sizes:
|
||||
decoder_layers.append(nn.Linear(prev_layer_size, layer_size))
|
||||
encoder_layers.append(nn.ReLU())
|
||||
decoder_layers.append(nn.BatchNorm1d(layer_size))
|
||||
for layer_size in decoder_layers:
|
||||
decoder_layers_list.append(nn.Linear(prev_layer_size, layer_size))
|
||||
encoder_layers_list.append(nn.ReLU())
|
||||
decoder_layers_list.append(nn.BatchNorm1d(layer_size))
|
||||
prev_layer_size = layer_size
|
||||
|
||||
# Softmax is not used at the end of the network because the
|
||||
# CrossEntropyLoss criterion is used for training, which includes
|
||||
# LogSoftmax
|
||||
decoder_layers.append(nn.Linear(prev_layer_size, order))
|
||||
decoder_layers_list.append(nn.Linear(prev_layer_size, order))
|
||||
|
||||
self.decoder = nn.Sequential(*decoder_layers)
|
||||
self.decoder = nn.Sequential(*decoder_layers_list)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
|
@ -88,7 +85,7 @@ class ConstellationNet(nn.Module):
|
|||
|
||||
def get_constellation(self):
|
||||
"""
|
||||
Extract symbol constellation out of the trained encoder.
|
||||
Extract the symbol constellation out of the trained encoder.
|
||||
|
||||
:return: Matrix containing `order` rows with the nᵗʰ one being the I/Q
|
||||
vector that is the result of encoding the nᵗʰ message.
|
||||
|
|
|
@ -4,6 +4,10 @@ import math
|
|||
|
||||
|
||||
class GaussianChannel(nn.Module):
|
||||
"""
|
||||
Simulated communication channel that assumes a Gaussian noise model for
|
||||
taking in account interference.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
|
Loading…
Reference in New Issue