import os

import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero

import constellation
from constellation import util

										 |  |  | torch.manual_seed(42) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | # Number of symbols to learn | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  | order = 16 | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  | # Number of batches to skip between every loss report | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  | loss_report_batch_skip = 50 | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  | # Size of batches | 
					
						
							|  |  |  | batch_size = 32 | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-13 20:17:57 +00:00
										 |  |  | # File in which the trained model is saved | 
					
						
							| 
									
										
										
										
											2019-12-15 04:04:35 +00:00
										 |  |  | output_file = 'output/constellation-order-{}.pth'.format(order) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  | ### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Setup plot for showing training progress | 
					
						
							|  |  |  | fig = pyplot.figure() | 
					
						
							|  |  |  | ax = SubplotZero(fig, 111) | 
					
						
							|  |  |  | fig.add_subplot(ax) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | pyplot.show(block=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Train the model with random data | 
					
						
							| 
									
										
										
										
											2019-12-15 04:04:35 +00:00
										 |  |  | model = constellation.ConstellationNet( | 
					
						
							|  |  |  |     order=order, | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  |     encoder_layers_sizes=(8, 4), | 
					
						
							|  |  |  |     decoder_layers_sizes=(4, 8), | 
					
						
							| 
									
										
										
										
											2019-12-15 04:04:35 +00:00
										 |  |  |     channel_model=constellation.GaussianChannel() | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2019-12-13 22:10:40 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  | print('Starting training\n') | 
					
						
							| 
									
										
										
										
											2019-12-13 20:17:57 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  | # Current batch index | 
					
						
							|  |  |  | batch = 0 | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  | # Accumulated loss for last batches | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | running_loss = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  | # List of training examples (not shuffled) | 
					
						
							|  |  |  | classes_ordered = torch.arange(order).repeat(batch_size) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Constellation from the previous training batch | 
					
						
							|  |  |  | prev_constellation = model.get_constellation() | 
					
						
							|  |  |  | total_change = float('inf') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  | # Optimizer settings | 
					
						
							|  |  |  | criterion = torch.nn.CrossEntropyLoss() | 
					
						
							|  |  |  | optimizer = torch.optim.Adam(model.parameters(), lr=0.1) | 
					
						
							|  |  |  | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( | 
					
						
							|  |  |  |     optimizer, verbose=True, | 
					
						
							|  |  |  |     factor=0.25, | 
					
						
							|  |  |  |     patience=100, | 
					
						
							|  |  |  |     cooldown=50, | 
					
						
							|  |  |  |     threshold=1e-8 | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | while total_change >= 1e-4: | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |     # Shuffle training data and convert to one-hot encoding | 
					
						
							| 
									
										
										
										
											2019-12-15 05:59:10 +00:00
										 |  |  |     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] | 
					
						
							| 
									
										
										
										
											2019-12-13 20:17:57 +00:00
										 |  |  |     onehot_dataset = util.messages_to_onehot(classes_dataset, order) | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |     # Perform training step for current batch | 
					
						
							|  |  |  |     model.train() | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  |     optimizer.zero_grad() | 
					
						
							|  |  |  |     predictions = model(onehot_dataset) | 
					
						
							|  |  |  |     loss = criterion(predictions, classes_dataset) | 
					
						
							|  |  |  |     loss.backward() | 
					
						
							|  |  |  |     optimizer.step() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  |     # Update learning rate scheduler | 
					
						
							|  |  |  |     scheduler.step(loss) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |     # Check for convergence | 
					
						
							|  |  |  |     model.eval() | 
					
						
							|  |  |  |     constellation = model.get_constellation() | 
					
						
							| 
									
										
										
										
											2019-12-16 14:55:06 +00:00
										 |  |  |     total_change = (constellation - prev_constellation).norm(dim=1).sum() | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |     prev_constellation = constellation | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  |     # Report loss | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  |     running_loss += loss.item() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |     if batch % loss_report_batch_skip == loss_report_batch_skip - 1: | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  |         print('Batch #{}'.format(batch + 1)) | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |         print('\tLoss is {}'.format(running_loss / loss_report_batch_skip)) | 
					
						
							|  |  |  |         print('\tChange is {}\n'.format(total_change)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-13 17:11:09 +00:00
										 |  |  |         running_loss = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 07:30:05 +00:00
										 |  |  |     # Update figure with current encoding | 
					
						
							|  |  |  |     ax.clear() | 
					
						
							|  |  |  |     util.plot_constellation( | 
					
						
							|  |  |  |         ax, constellation, | 
					
						
							|  |  |  |         model.channel, model.decoder, | 
					
						
							|  |  |  |         noise_samples=0 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     fig.canvas.draw() | 
					
						
							|  |  |  |     fig.canvas.flush_events() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 00:42:50 +00:00
										 |  |  |     batch += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 15:20:01 +00:00
										 |  |  | model.eval() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Calcul de la perte finale | 
					
						
							|  |  |  | with torch.no_grad(): | 
					
						
							|  |  |  |     classes_ordered = torch.arange(order).repeat(2048) | 
					
						
							|  |  |  |     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] | 
					
						
							|  |  |  |     onehot_dataset = util.messages_to_onehot(classes_dataset, order) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     predictions = model(onehot_dataset) | 
					
						
							|  |  |  |     final_loss = criterion(predictions, classes_dataset) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-16 14:17:15 +00:00
										 |  |  | print('\nFinished training') | 
					
						
							| 
									
										
										
										
											2019-12-16 15:20:01 +00:00
										 |  |  | print('Final loss is {}'.format(final_loss)) | 
					
						
							| 
									
										
										
										
											2019-12-16 14:17:15 +00:00
										 |  |  | print('Saving model as {}'.format(output_file)) | 
					
						
							| 
									
										
										
										
											2019-12-13 20:17:57 +00:00
										 |  |  | torch.save(model.state_dict(), output_file) |