Change training strategy to adaptive learning rate
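
Replaces the two-phase coarse/fine optimization (batches of 8, then 2048, with a manual switch between phases) by a single batch size of 32 and an adaptive learning rate: Adam now starts at lr=0.1 and a ReduceLROnPlateau scheduler divides the rate by 4 whenever the loss stops improving. The training loop exits once the total constellation change drops below 1e-4 instead of breaking out of an infinite loop. The constellation order grows from 4 to 16 symbols, the encoder and decoder each gain a hidden layer, and the live plot is now redrawn on every batch with noise_samples=0.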
parent 3f2c6d18a3
commit 8fa6b46ca8
							
								
								
									
plot.py (8 changed lines)
@@ -6,7 +6,7 @@ import matplotlib
 from mpl_toolkits.axisartist.axislines import SubplotZero
 
 # Number learned symbols
-order = 4
+order = 16
 
 # File in which the trained model is saved
 input_file = 'output/constellation-order-{}.pth'.format(order)
@@ -17,8 +17,8 @@ color_map = matplotlib.cm.Dark2
 # Restore model from file
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(8,),
-    decoder_layers_sizes=(8,),
+    encoder_layers_sizes=(8, 4),
+    decoder_layers_sizes=(4, 8),
     channel_model=constellation.GaussianChannel()
 )
 
@@ -34,7 +34,7 @@ constellation = model.get_constellation()
 util.plot_constellation(
     ax, constellation,
     model.channel, model.decoder,
-    grid_step=0.001, noise_samples=2500
+    grid_step=0.001, noise_samples=0
 )
 
 pyplot.show()
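
Note that plot.py has to rebuild the network with exactly the same order and layer sizes that train.py used before it can restore the saved weights, which is why both files change in lockstep. Below is a minimal sketch of that restore step in plain PyTorch, assuming ConstellationNet is a torch.nn.Module and the checkpoint holds a state dict written by torch.save(model.state_dict(), ...); the repository's actual save/load code is not part of this diff.

import torch
import constellation

order = 16

# Rebuild the architecture exactly as trained, then load the weights:
# load_state_dict() raises if these layer sizes differ from training.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8, 4),
    decoder_layers_sizes=(4, 8),
    channel_model=constellation.GaussianChannel()
)
state = torch.load('output/constellation-order-{}.pth'.format(order))
model.load_state_dict(state)
model.eval()  # disable training-only behavior before plotting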
							
								
								
									
train.py (64 changed lines)
@@ -7,16 +7,13 @@ from mpl_toolkits.axisartist.axislines import SubplotZero
 torch.manual_seed(42)
 
 # Number of symbols to learn
-order = 4
+order = 16
 
 # Number of batches to skip between every loss report
-loss_report_batch_skip = 500
+loss_report_batch_skip = 50
 
-# Size of batches during coarse optimization (small batches)
-coarse_batch_size = 8
-
-# Size of batches during fine optimization (large batches)
-fine_batch_size = 2048
+# Size of batches
+batch_size = 32
 
 # File in which the trained model is saved
 output_file = 'output/constellation-order-{}.pth'.format(order)
@@ -33,15 +30,15 @@ pyplot.show(block=False)
 # Train the model with random data
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(8,),
-    decoder_layers_sizes=(8,),
+    encoder_layers_sizes=(8, 4),
+    decoder_layers_sizes=(4, 8),
    channel_model=constellation.GaussianChannel()
 )
 
 print('Starting training\n')
 
-criterion = torch.nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+# Current batch index
+batch = 0
 
 # Accumulated loss for last batches
 running_loss = 0
@@ -50,12 +47,6 @@ running_loss = 0
 # the second phase where point positions are refined using large batches
 is_coarse_optim = True
 
-# Current batch index
-batch = 1
-
-# Current batch size
-batch_size = coarse_batch_size
-
 # List of training examples (not shuffled)
 classes_ordered = torch.arange(order).repeat(batch_size)
 
@@ -63,7 +54,18 @@ classes_ordered = torch.arange(order).repeat(batch_size)
 prev_constellation = model.get_constellation()
 total_change = float('inf')
 
-while True:
+# Optimizer settings
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+    optimizer, verbose=True,
+    factor=0.25,
+    patience=100,
+    cooldown=50,
+    threshold=1e-8
+)
+
+while total_change >= 1e-4:
     # Shuffle training data and convert to one-hot encoding
     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
@@ -76,38 +78,34 @@ while True:
     loss.backward()
     optimizer.step()
 
+    # Update learning rate scheduler
+    scheduler.step(loss)
+
     # Check for convergence
     model.eval()
     constellation = model.get_constellation()
     total_change = (constellation - prev_constellation).abs().sum()
     prev_constellation = constellation
 
-    if is_coarse_optim:
-        if total_change < 1e-5:
-            print('Changing to fine optimization')
-            is_coarse_optim = False
-            batch_size = fine_batch_size
-            classes_ordered = torch.arange(order).repeat(batch_size)
-    elif total_change < 1e-5:
-        break
-
-    # Report loss and update figure (if applicable)
+    # Report loss
     running_loss += loss.item()
 
     if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
-        print('Batch #{} (size {})'.format(batch + 1, batch_size))
+        print('Batch #{}'.format(batch + 1))
         print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
         print('\tChange is {}\n'.format(total_change))
 
-        # Update figure with current encoding
-        ax.clear()
-        util.plot_constellation(
-            ax, constellation,
-            model.channel, model.decoder
-        )
-        fig.canvas.draw()
-        pyplot.pause(1e-17)
-
         running_loss = 0
 
+    # Update figure with current encoding
+    ax.clear()
+    util.plot_constellation(
+        ax, constellation,
+        model.channel, model.decoder,
+        noise_samples=0
+    )
+    fig.canvas.draw()
+    fig.canvas.flush_events()
+
     batch += 1
+
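
For reference, the adaptive strategy named in the commit title is PyTorch's ReduceLROnPlateau scheduler: it watches a metric (here the per-batch loss) and multiplies the learning rate by factor once the metric has failed to improve by more than threshold for patience consecutive steps, then waits cooldown steps before it may reduce again. A self-contained sketch of the same pattern on a toy regression problem follows; the linear model and data are illustrative stand-ins, not part of this repository.

import torch

torch.manual_seed(42)

# Toy stand-ins for ConstellationNet and its training batches
model = torch.nn.Linear(2, 2)
criterion = torch.nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.25,     # multiply the learning rate by 0.25 on each plateau
    patience=100,    # tolerate 100 steps without improvement first
    cooldown=50,     # then wait 50 steps before reducing again
    threshold=1e-8   # improvements smaller than this do not count
)

inputs = torch.randn(32, 2)
targets = inputs @ torch.tensor([[2.0, 0.0], [0.0, -1.0]])

for step in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step(loss)  # the scheduler tracks the metric, not gradients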