Save whole model to avoid definition duplicate
This commit is contained in:
		
							parent
							
								
									0769a61fcf
								
							
						
					
					
						commit
						a7e9dd2230
					
				
							
								
								
									
										17
									
								
								plot.py
								
								
								
								
							
							
						
						
									
										17
									
								
								plot.py
								
								
								
								
							|  | @ -8,21 +8,14 @@ from mpl_toolkits.axisartist.axislines import SubplotZero | |||
| # Number learned symbols | ||||
| order = 16 | ||||
| 
 | ||||
| # Color map used for decision regions and points | ||||
| color_map = matplotlib.cm.Dark2 | ||||
| 
 | ||||
| # File in which the trained model is saved | ||||
| input_file = 'output/constellation-order-{}.pth'.format(order) | ||||
| 
 | ||||
| # Color map used for decision regions | ||||
| color_map = matplotlib.cm.Dark2 | ||||
| 
 | ||||
| # Restore model from file | ||||
| model = constellation.ConstellationNet( | ||||
|     order=order, | ||||
|     encoder_layers_sizes=(8, 4), | ||||
|     decoder_layers_sizes=(4, 8), | ||||
|     channel_model=constellation.GaussianChannel() | ||||
| ) | ||||
| 
 | ||||
| model.load_state_dict(torch.load(input_file)) | ||||
| model = torch.load(input_file) | ||||
| model.eval() | ||||
| 
 | ||||
| # Setup plot | ||||
|  | @ -34,7 +27,7 @@ constellation = model.get_constellation() | |||
| util.plot_constellation( | ||||
|     ax, constellation, | ||||
|     model.channel, model.decoder, | ||||
|     grid_step=0.001, noise_samples=5000 | ||||
|     grid_step=0.001, noise_samples=0 | ||||
| ) | ||||
| 
 | ||||
| pyplot.show() | ||||
|  |  | |||
							
								
								
									
										16
									
								
								train.py
								
								
								
								
							
							
						
						
									
										16
									
								
								train.py
								
								
								
								
							|  | @ -3,12 +3,16 @@ from constellation import util | |||
| import torch | ||||
| from matplotlib import pyplot | ||||
| from mpl_toolkits.axisartist.axislines import SubplotZero | ||||
| import warnings | ||||
| 
 | ||||
| torch.manual_seed(57) | ||||
| torch.manual_seed(42) | ||||
| 
 | ||||
| # Number of symbols to learn | ||||
| order = 16 | ||||
| 
 | ||||
| # Shape of the hidden layers | ||||
| hidden_layers = (8, 4,) | ||||
| 
 | ||||
| # Initial value for the learning rate | ||||
| initial_learning_rate = 0.1 | ||||
| 
 | ||||
|  | @ -33,9 +37,8 @@ pyplot.show(block=False) | |||
| # Train the model with random data | ||||
| model = constellation.ConstellationNet( | ||||
|     order=order, | ||||
|     encoder_layers_sizes=(8, 4,), | ||||
|     decoder_layers_sizes=(4, 8,), | ||||
|     channel_model=constellation.GaussianChannel() | ||||
|     encoder_layers=hidden_layers, | ||||
|     decoder_layers=hidden_layers[::-1], | ||||
| ) | ||||
| 
 | ||||
| print('Starting training\n') | ||||
|  | @ -122,4 +125,7 @@ with torch.no_grad(): | |||
| print('\nFinished training') | ||||
| print('Final loss is {}'.format(final_loss)) | ||||
| print('Saving model as {}'.format(output_file)) | ||||
| torch.save(model.state_dict(), output_file) | ||||
| 
 | ||||
| with warnings.catch_warnings(): | ||||
|     warnings.simplefilter('ignore') | ||||
|     torch.save(model, output_file) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue