Use best parameters as found by experimentation
This commit is contained in:
		
							parent
							
								
									5cb087d971
								
							
						
					
					
						commit
						0769a61fcf
					
				
							
								
								
									
										13
									
								
								train.py
								
								
								
								
							
							
						
						
									
										13
									
								
								train.py
								
								
								
								
							|  | @ -4,16 +4,19 @@ import torch | ||||||
| from matplotlib import pyplot | from matplotlib import pyplot | ||||||
| from mpl_toolkits.axisartist.axislines import SubplotZero | from mpl_toolkits.axisartist.axislines import SubplotZero | ||||||
| 
 | 
 | ||||||
| torch.manual_seed(42) | torch.manual_seed(57) | ||||||
| 
 | 
 | ||||||
| # Number of symbols to learn | # Number of symbols to learn | ||||||
| order = 16 | order = 16 | ||||||
| 
 | 
 | ||||||
|  | # Initial value for the learning rate | ||||||
|  | initial_learning_rate = 0.1 | ||||||
|  | 
 | ||||||
| # Number of batches to skip between every loss report | # Number of batches to skip between every loss report | ||||||
| loss_report_batch_skip = 50 | loss_report_batch_skip = 50 | ||||||
| 
 | 
 | ||||||
| # Size of batches | # Size of batches | ||||||
| batch_size = 32 | batch_size = 2048 | ||||||
| 
 | 
 | ||||||
| # File in which the trained model is saved | # File in which the trained model is saved | ||||||
| output_file = 'output/constellation-order-{}.pth'.format(order) | output_file = 'output/constellation-order-{}.pth'.format(order) | ||||||
|  | @ -30,8 +33,8 @@ pyplot.show(block=False) | ||||||
| # Train the model with random data | # Train the model with random data | ||||||
| model = constellation.ConstellationNet( | model = constellation.ConstellationNet( | ||||||
|     order=order, |     order=order, | ||||||
|     encoder_layers_sizes=(8, 4), |     encoder_layers_sizes=(8, 4,), | ||||||
|     decoder_layers_sizes=(4, 8), |     decoder_layers_sizes=(4, 8,), | ||||||
|     channel_model=constellation.GaussianChannel() |     channel_model=constellation.GaussianChannel() | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -52,7 +55,7 @@ total_change = float('inf') | ||||||
| 
 | 
 | ||||||
| # Optimizer settings | # Optimizer settings | ||||||
| criterion = torch.nn.CrossEntropyLoss() | criterion = torch.nn.CrossEntropyLoss() | ||||||
| optimizer = torch.optim.Adam(model.parameters(), lr=0.1) | optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate) | ||||||
| scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( | ||||||
|     optimizer, verbose=True, |     optimizer, verbose=True, | ||||||
|     factor=0.25, |     factor=0.25, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue