"""Train a ConstellationNet autoencoder on shuffled messages and save it.

The script builds a small encoder/decoder network around a Gaussian channel
model, trains it with cross-entropy loss on one-hot encoded messages, prints
a few reconstruction examples, and writes the model's state dict to disk.
"""
import os

import torch

import constellation
from constellation import util

# Number of symbols to learn
order = 4

# Number of times each symbol appears in one epoch's training set
# (an epoch therefore holds order * epoch_size_multiple examples)
epoch_size_multiple = 8

# Number of epochs
num_epochs = 5000

# Number of epochs to skip between every loss report
loss_report_epoch_skip = 500

# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)

model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel()
)

# Train the model with random data
model.train()

print('Starting training with {} epochs\n'.format(num_epochs))

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Sum of the per-epoch losses since the last report (not an average)
running_loss = 0

# Every symbol index repeated epoch_size_multiple times; reshuffled each epoch
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)

for epoch in range(num_epochs):
    # Same balanced set of classes every epoch, in a fresh random order
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)

    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    # CrossEntropyLoss consumes the raw model outputs and class indices
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Report loss
    running_loss += loss.item()

    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('Loss is {}'.format(running_loss))
        running_loss = 0

print('\nFinished training\n')

# Print some examples of reconstruction
model.eval()
print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')

# no_grad() already disables gradient tracking for everything computed here,
# so no per-tensor flag needs to be touched on the outputs.
with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    raw_output = model(onehot_example)
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)

    for index in range(order):
        print('{}\t\t{}'.format(
            onehot_example[index].tolist(),
            '[{}]'.format(', '.join(
                '{:.5f}'.format(x)
                for x in reconstructed_example[index].tolist()
            ))
        ))

print('\nSaving model as {}'.format(output_file))

# torch.save does not create missing parent directories; make sure the target
# directory exists so a long training run is not lost at the final step.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

torch.save(model.state_dict(), output_file)