Improve plot
This commit is contained in:
		
							parent
							
								
									49e63775dd
								
							
						
					
					
						commit
						a01d83f339
					
				
							
								
								
									
										77
									
								
								plot.py
								
								
								
								
							
							
						
						
									
										77
									
								
								plot.py
								
								
								
								
							|  | @ -2,12 +2,15 @@ import constellation | |||
| from constellation import util | ||||
| import torch | ||||
| from matplotlib import pyplot | ||||
| from mpl_toolkits.axisartist.axislines import SubplotZero | ||||
| import math | ||||
| import numpy | ||||
| 
 | ||||
| # Number learned symbols | ||||
| order = 4 | ||||
| 
 | ||||
| # File in which the trained model is saved | ||||
| input_file = 'output/constellation-net.tc' | ||||
| input_file = 'output/constellation-order-{}.tc'.format(order) | ||||
| 
 | ||||
| model = constellation.ConstellationNet(order=order) | ||||
| model.load_state_dict(torch.load(input_file)) | ||||
|  | @ -19,8 +22,74 @@ with torch.no_grad(): | |||
|             torch.arange(0, order), | ||||
|             order | ||||
|         ) | ||||
|     ).tolist() | ||||
|     ) | ||||
| 
 | ||||
| fig = pyplot.figure() | ||||
| ax = SubplotZero(fig, 111) | ||||
| fig.add_subplot(ax) | ||||
| 
 | ||||
| # Extend axes symmetrically around zero so that they fit data | ||||
| extent = max( | ||||
|     abs(encoded_vectors.min()), | ||||
|     abs(encoded_vectors.max()) | ||||
| ) * 1.05 | ||||
| 
 | ||||
| ax.set_xlim(-extent, extent) | ||||
| ax.set_ylim(-extent, extent) | ||||
| 
 | ||||
| # Hide borders | ||||
| for direction in ['left', 'bottom', 'right', 'top']: | ||||
|     ax.axis[direction].set_visible(False) | ||||
| 
 | ||||
| # Show zero-centered axes | ||||
| for direction in ['xzero', 'yzero']: | ||||
|     axis = ax.axis[direction] | ||||
|     axis.set_visible(True) | ||||
|     axis.set_axisline_style("-|>") | ||||
| 
 | ||||
| # Configure axes ticks and labels | ||||
| ax.annotate( | ||||
|     'I', (1, 0.5), xycoords='axes fraction', | ||||
|     xytext=(25, 0), textcoords='offset points', | ||||
|     va='center', ha='right' | ||||
| ) | ||||
| 
 | ||||
| ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white') | ||||
| ax.axis['xzero'].major_ticklabels.set_ha('center') | ||||
| ax.axis['xzero'].major_ticklabels.set_va('top') | ||||
| 
 | ||||
| ax.annotate( | ||||
|     'Q', (0.5, 1), xycoords='axes fraction', | ||||
|     xytext=(0, 25), textcoords='offset points', | ||||
|     va='center', ha='center' | ||||
| ) | ||||
| 
 | ||||
| ax.axis['yzero'].major_ticklabels.set_rotation(-90) | ||||
| ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white') | ||||
| ax.axis['yzero'].major_ticklabels.set_ha('left') | ||||
| ax.axis['yzero'].major_ticklabels.set_va('center') | ||||
| 
 | ||||
| # Add a single tick on 0 | ||||
| ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0]) | ||||
| ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0]) | ||||
| 
 | ||||
| ax.annotate( | ||||
|     '0', (0, 0), | ||||
|     xytext=(15, -10), textcoords='offset points', | ||||
|     va='center', ha='center' | ||||
| ) | ||||
| 
 | ||||
| ax.grid() | ||||
| 
 | ||||
| # Plot encoded vectors | ||||
| ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10) | ||||
| 
 | ||||
| # Add index label for each vector | ||||
| for row in range(order): | ||||
|     ax.annotate( | ||||
|         row + 1, encoded_vectors[row], | ||||
|         xytext=(5, 5), textcoords='offset points', | ||||
|         backgroundcolor='white', zorder=9 | ||||
|     ) | ||||
| 
 | ||||
| fig, axis = pyplot.subplots() | ||||
| axis.scatter(*zip(*encoded_vectors)) | ||||
| pyplot.show() | ||||
|  |  | |||
							
								
								
									
										27
									
								
								train.py
								
								
								
								
							
							
						
						
									
										27
									
								
								train.py
								
								
								
								
							|  | @ -15,11 +15,14 @@ num_epochs = 20000 | |||
| loss_report_epoch_skip = 500 | ||||
| 
 | ||||
| # File in which the trained model is saved | ||||
| output_file = 'output/constellation-net.tc' | ||||
| 
 | ||||
| print('Starting training with {} epochs\n'.format(num_epochs)) | ||||
| output_file = 'output/constellation-order-{}.tc'.format(order) | ||||
| 
 | ||||
| model = constellation.ConstellationNet(order=order) | ||||
| 
 | ||||
| # Train the model with random data | ||||
| model.train() | ||||
| print('Starting training with {} epochs\n'.format(num_epochs)) | ||||
| 
 | ||||
| criterion = torch.nn.CrossEntropyLoss() | ||||
| optimizer = torch.optim.Adam(model.parameters()) | ||||
| 
 | ||||
|  | @ -46,24 +49,22 @@ for epoch in range(num_epochs): | |||
| print('\nFinished training\n') | ||||
| 
 | ||||
| # Print some examples of reconstruction | ||||
| with torch.no_grad(): | ||||
|     num_examples = 5 | ||||
| model.eval() | ||||
| print('Reconstruction examples:') | ||||
| print('Input vector\t\t\tOutput vector after softmax') | ||||
| 
 | ||||
|     classes_example = util.get_random_messages(num_examples, order) | ||||
|     onehot_example = util.messages_to_onehot(classes_example, order) | ||||
| with torch.no_grad(): | ||||
|     onehot_example = util.messages_to_onehot(torch.arange(0, order), order) | ||||
|     raw_output = model(onehot_example) | ||||
|     raw_output.required_grad = False | ||||
|     reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1) | ||||
| 
 | ||||
|     print('Reconstruction examples:') | ||||
|     print('Input vector\t\t\tOutput vector after softmax') | ||||
| 
 | ||||
|     for example_index in range(num_examples): | ||||
|     for index in range(order): | ||||
|         print('{}\t\t{}'.format( | ||||
|             onehot_example[example_index].tolist(), | ||||
|             onehot_example[index].tolist(), | ||||
|             '[{}]'.format(', '.join( | ||||
|                 '{:.5f}'.format(x) | ||||
|                 for x in reconstructed_example[example_index].tolist() | ||||
|                 for x in reconstructed_example[index].tolist() | ||||
|             )) | ||||
|         )) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue