diff --git a/plot.py b/plot.py
index 93b74cd..123b616 100644
--- a/plot.py
+++ b/plot.py
@@ -6,7 +6,7 @@ import matplotlib
 from mpl_toolkits.axisartist.axislines import SubplotZero
 
 # Number learned symbols
-order = 4
+order = 16
 
 # File in which the trained model is saved
 input_file = 'output/constellation-order-{}.pth'.format(order)
@@ -17,8 +17,8 @@ color_map = matplotlib.cm.Dark2
 # Restore model from file
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(8,),
-    decoder_layers_sizes=(8,),
+    encoder_layers_sizes=(8, 4),
+    decoder_layers_sizes=(4, 8),
     channel_model=constellation.GaussianChannel()
 )
 
@@ -34,7 +34,7 @@ constellation = model.get_constellation()
 util.plot_constellation(
     ax, constellation,
     model.channel, model.decoder,
-    grid_step=0.001, noise_samples=2500
+    grid_step=0.001, noise_samples=0
 )
 
 pyplot.show()
diff --git a/train.py b/train.py
index 87fad5e..11e5ac7 100644
--- a/train.py
+++ b/train.py
@@ -7,16 +7,13 @@ from mpl_toolkits.axisartist.axislines import SubplotZero
 torch.manual_seed(42)
 
 # Number of symbols to learn
-order = 4
+order = 16
 
 # Number of batches to skip between every loss report
-loss_report_batch_skip = 500
+loss_report_batch_skip = 50
 
-# Size of batches during coarse optimization (small batches)
-coarse_batch_size = 8
-
-# Size of batches during fine optimization (large batches)
-fine_batch_size = 2048
+# Size of batches
+batch_size = 32
 
 # File in which the trained model is saved
 output_file = 'output/constellation-order-{}.pth'.format(order)
@@ -33,15 +30,15 @@ pyplot.show(block=False)
 # Train the model with random data
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(8,),
-    decoder_layers_sizes=(8,),
+    encoder_layers_sizes=(8, 4),
+    decoder_layers_sizes=(4, 8),
     channel_model=constellation.GaussianChannel()
 )
 
 print('Starting training\n')
 
-criterion = torch.nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+# Current batch index
+batch = 0
 
 # Accumulated loss for last batches
 running_loss = 0
@@ -50,12 +47,6 @@ running_loss = 0
 # the second phase where point positions are refined using large batches
 is_coarse_optim = True
 
-# Current batch index
-batch = 1
-
-# Current batch size
-batch_size = coarse_batch_size
-
 # List of training examples (not shuffled)
 classes_ordered = torch.arange(order).repeat(batch_size)
 
@@ -63,7 +54,18 @@ classes_ordered = torch.arange(order).repeat(batch_size)
 prev_constellation = model.get_constellation()
 total_change = float('inf')
 
-while True:
+# Optimizer settings
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+    optimizer, verbose=True,
+    factor=0.25,
+    patience=100,
+    cooldown=50,
+    threshold=1e-8
+)
+
+while total_change >= 1e-4:
     # Shuffle training data and convert to one-hot encoding
     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
@@ -76,39 +78,35 @@ while True:
     loss.backward()
     optimizer.step()
 
+    # Update learning rate scheduler
+    scheduler.step(loss)
+
     # Check for convergence
     model.eval()
     constellation = model.get_constellation()
     total_change = (constellation - prev_constellation).abs().sum()
     prev_constellation = constellation
 
-    if is_coarse_optim:
-        if total_change < 1e-5:
-            print('Changing to fine optimization')
-            is_coarse_optim = False
-            batch_size = fine_batch_size
-            classes_ordered = torch.arange(order).repeat(batch_size)
-    elif total_change < 1e-5:
-        break
-
-    # Report loss and update figure (if applicable)
+    # Report loss
     running_loss += loss.item()
 
     if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
-        print('Batch #{} (size {})'.format(batch + 1, batch_size))
+        print('Batch #{}'.format(batch + 1))
         print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
         print('\tChange is {}\n'.format(total_change))
 
-        ax.clear()
-        util.plot_constellation(
-            ax, constellation,
-            model.channel, model.decoder
-        )
-        fig.canvas.draw()
-        pyplot.pause(1e-17)
-
         running_loss = 0
 
+    # Update figure with current encoding
+    ax.clear()
+    util.plot_constellation(
+        ax, constellation,
+        model.channel, model.decoder,
+        noise_samples=0
+    )
+    fig.canvas.draw()
+    fig.canvas.flush_events()
+
     batch += 1
 
 print('\nFinished training\n')
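Note: this patch replaces the hand-rolled two-phase schedule (small coarse batches, then large fine batches) with a ReduceLROnPlateau learning-rate schedule and an explicit convergence test on the constellation. A minimal standalone sketch of that pattern, assuming a toy linear model and random data as stand-ins for ConstellationNet and the real training set (neither is part of this patch):

```python
import torch

# Stand-ins for ConstellationNet and its data (assumptions for the sketch)
model = torch.nn.Linear(2, 4)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.25,    # multiply the learning rate by 0.25 on a plateau
    patience=100,   # tolerate 100 steps without sufficient improvement
    cooldown=50,    # wait 50 steps after a reduction before monitoring again
    threshold=1e-8  # minimum loss change that counts as improvement
)

inputs = torch.randn(32, 2)
targets = torch.randint(0, 4, (32,))

for step in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Passing the monitored metric to step() drives the plateau
    # detection; this stands in for the old coarse/fine phase switch.
    scheduler.step(loss)
```

The effect is similar in spirit to the removed code: optimization starts aggressively (lr=0.1) and is refined automatically as the loss stops improving, instead of relying on a hand-tuned batch-size switch at a fixed change threshold.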