# Python script (79 lines, 2.2 KiB): trains a ConstellationNet on random
# messages, prints reconstruction examples and saves the trained weights.
import os

import torch

import constellation
from constellation import util

# Number of symbols to learn
order = 4

# Number of times every symbol appears in one epoch; an epoch therefore
# contains order * epoch_size_multiple training examples
epoch_size_multiple = 8

# Number of epochs
num_epochs = 5000

# Number of epochs between consecutive loss reports
loss_report_epoch_skip = 500

# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)

# Network under training. The layer-size tuples are passed straight to
# constellation.ConstellationNet (presumably hidden-layer widths for the
# encoder and decoder — confirm in the constellation package), and the
# channel model sits between them.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel(),
)

# Train the model with random data
model.train()
print('Starting training with {} epochs\n'.format(num_epochs))

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Sum of the per-epoch losses since the last report, and how many epochs it
# covers; the printed value is their mean so it does not grow with the
# report interval
running_loss = 0.0
epochs_since_report = 0

# Every class appears exactly `epoch_size_multiple` times per epoch, so each
# epoch is perfectly balanced; only the order is reshuffled below
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)

for epoch in range(num_epochs):
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)

    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    # CrossEntropyLoss takes the raw model output and the class indices
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Report loss
    running_loss += loss.item()
    epochs_since_report += 1

    # Also report on the final epoch so trailing epochs are never silently
    # dropped when num_epochs is not a multiple of loss_report_epoch_skip
    is_last_epoch = epoch == num_epochs - 1
    if epochs_since_report == loss_report_epoch_skip or is_last_epoch:
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('Loss is {}'.format(running_loss / epochs_since_report))
        running_loss = 0.0
        epochs_since_report = 0

print('\nFinished training\n')

# Print some examples of reconstruction: every message 0..order-1 is
# one-hot encoded, passed through the trained model, and the softmax of the
# model output is printed next to the input vector
model.eval()
print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')

with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    # Tensors produced inside torch.no_grad() already have
    # requires_grad == False, so no further detaching is needed
    raw_output = model(onehot_example)
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)

    for onehot_row, probabilities in zip(onehot_example, reconstructed_example):
        formatted_probabilities = ', '.join(
            '{:.5f}'.format(x) for x in probabilities.tolist()
        )
        print('{}\t\t{}'.format(
            onehot_row.tolist(),
            '[{}]'.format(formatted_probabilities),
        ))

# Save the trained weights. torch.save fails when the parent directory is
# missing, so create it first ('.' when output_file has no directory part).
print('\nSaving model as {}'.format(output_file))
os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
torch.save(model.state_dict(), output_file)