Add initial implementation
This commit is contained in:
parent
7463e9fe5b
commit
2af8354a07
|
@ -0,0 +1 @@
|
|||
__pycache__
|
|
@ -0,0 +1,78 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class ConstellationNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
order=2,
|
||||
encoder_layers_sizes=(),
|
||||
decoder_layers_sizes=()
|
||||
):
|
||||
"""
|
||||
Create an encoder-decoder network to automatically shape a
|
||||
constellation of symbols for efficient communication over an optical
|
||||
fiber channel.
|
||||
|
||||
:param order: Order of the constellation, i.e. the number of messages
|
||||
that are to be transmitted or equivalently the number of symbols whose
|
||||
placements in the constellation have to be learned.
|
||||
:param encoder_layers_sizes: Shape of the encoder’s hidden layers. The
|
||||
size of this sequence is the number of hidden layers, with each element
|
||||
being a number which specifies the number of neurons in its channel.
|
||||
:param decoder_layers_sizes: Shape of the decoder’s hidden layers. Uses
|
||||
the same convention as `encoder_layers_sizes` above.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Build the encoder network taking a one-hot encoded message as input
|
||||
# and outputting an I/Q vector. The network additionally uses hidden
|
||||
# layers as specified in `encoder_layers_sizes`
|
||||
prev_layer_size = order
|
||||
encoder_layers = []
|
||||
|
||||
for layer_size in encoder_layers_sizes:
|
||||
encoder_layers.append(nn.Linear(prev_layer_size, layer_size))
|
||||
encoder_layers.append(nn.ReLU())
|
||||
prev_layer_size = layer_size
|
||||
|
||||
encoder_layers += [
|
||||
nn.Linear(prev_layer_size, 2),
|
||||
nn.ReLU(),
|
||||
# TODO: Normalization step
|
||||
]
|
||||
|
||||
self.encoder = nn.Sequential(*encoder_layers)
|
||||
|
||||
# TODO: Add real channel model
|
||||
self.channel = nn.Identity()
|
||||
|
||||
# Build the decoder network taking the noisy I/Q vector received from
|
||||
# the channel as input and outputting a probability vector for each
|
||||
# original message. The network additionally uses hidden layers as
|
||||
# specified in `decoder_layers_sizes`
|
||||
prev_layer_size = 2
|
||||
decoder_layers = []
|
||||
|
||||
for layer_size in decoder_layers_sizes:
|
||||
decoder_layers.append(nn.Linear(prev_layer_size, layer_size))
|
||||
decoder_layers.append(nn.ReLU())
|
||||
prev_layer_size = layer_size
|
||||
|
||||
# Softmax is not used at the end of the network because the
|
||||
# CrossEntropyLoss criterion is used for training, which includes
|
||||
# LogSoftmax
|
||||
decoder_layers.append(nn.Linear(prev_layer_size, order),)
|
||||
|
||||
self.decoder = nn.Sequential(*decoder_layers)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Perform encoding and decoding of an input vector and compute its
|
||||
reconstructed vector.
|
||||
|
||||
:param x: Original one-hot encoded data.
|
||||
:return: Reconstructed vector.
|
||||
"""
|
||||
symbol = self.encoder(x)
|
||||
noisy_symbol = self.channel(symbol)
|
||||
return self.decoder(noisy_symbol)
|
|
@ -0,0 +1,47 @@
|
|||
from ConstellationNet import ConstellationNet
|
||||
import torch
|
||||
|
||||
# Number of symbols to learn
|
||||
order = 4
|
||||
|
||||
# Number of training examples in an epoch
|
||||
epoch_size = 10000
|
||||
|
||||
# Number of epochs
|
||||
num_epochs = 25000
|
||||
|
||||
# Number of epochs to skip between every loss report
|
||||
loss_report_epoch_skip = 200
|
||||
|
||||
model = ConstellationNet(order=order)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
|
||||
running_loss = 0
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
classes_dataset = torch.randint(0, order, (epoch_size,))
|
||||
onehot_dataset = torch.nn.functional.one_hot(
|
||||
classes_dataset,
|
||||
num_classes=order
|
||||
).float()
|
||||
|
||||
optimizer.zero_grad()
|
||||
predictions = model(onehot_dataset)
|
||||
loss = criterion(predictions, classes_dataset)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Report loss
|
||||
running_loss += loss.item()
|
||||
|
||||
if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
|
||||
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
|
||||
print('Loss is {}'.format(running_loss))
|
||||
running_loss = 0
|
||||
|
||||
# Test the model with class 1
|
||||
print(model(torch.nn.functional.one_hot(torch.tensor(0), num_classes=order).float()))
|
||||
|
||||
# Test the model with class 2
|
||||
print(model(torch.nn.functional.one_hot(torch.tensor(1), num_classes=order).float()))
|
Loading…
Reference in New Issue