"""Train a ConstellationNet on random symbols and save its weights.

The script builds a small ``constellation.ConstellationNet``, trains it
with cross-entropy loss on shuffled batches in which every symbol class
appears equally often, prints a few reconstruction examples, and writes
the model's ``state_dict`` to ``output_file``.
"""

import os

import torch

import constellation
from constellation import util

# Number of symbols to learn
order = 4

# Number of epochs
num_epochs = 10000

# Number of epochs to skip between every loss report
loss_report_epoch_skip = 500

# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)

# Create the destination directory up front so that a missing directory
# is caught before training rather than after it, when the trained
# weights would otherwise be lost on a failed save.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel()
)

# Train the model with random data
model.train()

print('Starting training with {} epochs\n'.format(num_epochs))

# CrossEntropyLoss takes unnormalized scores and integer class targets.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

running_loss = 0
prev_epoch_size_multiple = 0

for epoch in range(num_epochs):
    # Each batch holds `epoch_size_multiple` copies of every class:
    # small batches for the first half of training, large ones after.
    epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048

    # Rebuild the ordered class tensor only when the batch size changes.
    if epoch_size_multiple != prev_epoch_size_multiple:
        classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
        prev_epoch_size_multiple = epoch_size_multiple

    # Shuffle so that the class order differs from one epoch to the next.
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    # NOTE(review): presumably turns class indices into one-hot rows of
    # width `order` -- confirm against constellation.util.
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)

    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Report loss, averaged over the epochs since the previous report
    running_loss += loss.item()

    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
        running_loss = 0

print('\nFinished training\n')

# Print some examples of reconstruction
model.eval()
print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')

# Gradient tracking is disabled for the whole demo, so the model output
# needs no further detaching. The original set a misspelled
# `required_grad` attribute here, which had no effect.
with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    raw_output = model(onehot_example)
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)

for index in range(order):
    print('{}\t\t{}'.format(
        onehot_example[index].tolist(),
        '[{}]'.format(', '.join(
            '{:.5f}'.format(x)
            for x in reconstructed_example[index].tolist()
        ))
    ))

print('\nSaving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)