import constellation
from constellation import util
import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero

torch.manual_seed(42)

# Number of symbols to learn
order = 4

# Number of batches to skip between every loss report
loss_report_batch_skip = 500

# Size of batches during coarse optimization (small batches)
coarse_batch_size = 8

# Size of batches during fine optimization (large batches)
fine_batch_size = 2048

# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)

###

# Setup plot for showing training progress
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
pyplot.show(block=False)

# Train the model with random data
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)

print('Starting training\n')

# Cross-entropy loss on the decoded messages, optimized with Adam
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Accumulated loss for last batches
running_loss = 0

# True in the first training phase where small batches are used, and false in
# the second phase where point positions are refined using large batches
is_coarse_optim = True

# Current batch index
batch = 1

# Current batch size
batch_size = coarse_batch_size

# List of training examples (not shuffled)
classes_ordered = torch.arange(order).repeat(batch_size)

# Constellation from the previous training batch
prev_constellation = model.get_constellation()
total_change = float('inf')

while True:
    # Shuffle training data and convert to one-hot encoding
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)

    # Perform training step for current batch
    model.train()
    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Check for convergence by measuring how much the constellation moved
    # since the previous batch
    model.eval()
    constellation = model.get_constellation()
    total_change = (constellation - prev_constellation).abs().sum()
    prev_constellation = constellation

    if is_coarse_optim:
        # Convergence in the coarse phase switches to fine optimization
        if total_change < 1e-5:
            print('Changing to fine optimization')
            is_coarse_optim = False
            batch_size = fine_batch_size
            classes_ordered = torch.arange(order).repeat(batch_size)
    elif total_change < 1e-5:
        # Convergence in the fine phase ends training
        break

    # Report loss and update figure (if applicable)
    running_loss += loss.item()

    if batch % loss_report_batch_skip == 0:
        print('Batch #{} (size {})'.format(batch, batch_size))
        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
        print('\tChange is {}\n'.format(total_change))

        ax.clear()
        util.plot_constellation(
            ax, constellation,
            model.channel, model.decoder
        )
        fig.canvas.draw()
        pyplot.pause(1e-17)

        running_loss = 0

    batch += 1

print('\nFinished training\n')

# Print some examples of reconstruction
model.eval()

print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')

with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    raw_output = model(onehot_example)
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)

for index in range(order):
    print('{}\t\t{}'.format(
        onehot_example[index].tolist(),
        '[{}]'.format(', '.join(
            '{:.5f}'.format(x)
            for x in reconstructed_example[index].tolist()
        ))
    ))

print('\nSaving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)
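
# A minimal sketch of how the saved weights could be reloaded for inference,
# e.g. in a separate script. It assumes the same ConstellationNet
# hyperparameters as used above; the constructor arguments must match those of
# the trained model for load_state_dict to succeed.
#
#     import torch
#     import constellation
#     from constellation import util
#
#     model = constellation.ConstellationNet(
#         order=4,
#         encoder_layers_sizes=(8,),
#         decoder_layers_sizes=(8,),
#         channel_model=constellation.GaussianChannel()
#     )
#     model.load_state_dict(torch.load('output/constellation-order-4.pth'))
#     model.eval()
#
#     with torch.no_grad():
#         # Send every possible message once and print the decoded classes
#         probe = util.messages_to_onehot(torch.arange(4), 4)
#         print(model(probe).argmax(dim=1))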