Fix plot scaling
This commit is contained in:
		
							parent
							
								
									8f6363ee21
								
							
						
					
					
						commit
						197a01e993
					
				
							
								
								
									
										46
									
								
								plot.py
								
								
								
								
							
							
						
						
									
										46
									
								
								plot.py
								
								
								
								
							|  | @ -14,13 +14,36 @@ input_file = 'output/constellation-order-{}.pth'.format(order) | |||
| # Color map used for decision regions | ||||
| color_map = matplotlib.cm.Set1 | ||||
| 
 | ||||
| # Restore model from file | ||||
| model = constellation.ConstellationNet( | ||||
|     order=order, | ||||
|     encoder_layers_sizes=(4,), | ||||
|     decoder_layers_sizes=(4,), | ||||
|     channel_model=constellation.GaussianChannel() | ||||
| ) | ||||
| 
 | ||||
| model.load_state_dict(torch.load(input_file)) | ||||
| model.eval() | ||||
| 
 | ||||
| # Extract encoding | ||||
| with torch.no_grad(): | ||||
|     encoded_vectors = model.encoder( | ||||
|         util.messages_to_onehot( | ||||
|             torch.arange(0, order), | ||||
|             order | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
| # Setup plot | ||||
| fig = pyplot.figure() | ||||
| ax = SubplotZero(fig, 111) | ||||
| fig.add_subplot(ax) | ||||
| 
 | ||||
| # Extend axes symmetrically around zero so that they fit data | ||||
| extent = 1.5 | ||||
| extent = max( | ||||
|     abs(encoded_vectors.min()), | ||||
|     abs(encoded_vectors.max()) | ||||
| ) * 1.05 | ||||
| ax.set_xlim(-extent, extent) | ||||
| ax.set_ylim(-extent, extent) | ||||
| 
 | ||||
|  | @ -68,21 +91,10 @@ ax.annotate( | |||
| 
 | ||||
| ax.grid() | ||||
| 
 | ||||
| # Restore model from file | ||||
| model = constellation.ConstellationNet( | ||||
|     order=order, | ||||
|     encoder_layers_sizes=(4,), | ||||
|     decoder_layers_sizes=(4,), | ||||
|     channel_model=constellation.GaussianChannel() | ||||
| ) | ||||
| 
 | ||||
| model.load_state_dict(torch.load(input_file)) | ||||
| model.eval() | ||||
| 
 | ||||
| # Plot decision regions | ||||
| color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) | ||||
| 
 | ||||
| step = 0.01 | ||||
| step = 0.001 * extent | ||||
| grid_range = torch.arange(-extent, extent, step) | ||||
| grid_y, grid_x = torch.meshgrid(grid_range, grid_range) | ||||
| grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) | ||||
|  | @ -97,14 +109,6 @@ ax.imshow( | |||
| ) | ||||
| 
 | ||||
| # Plot encoded vectors | ||||
| with torch.no_grad(): | ||||
|     encoded_vectors = model.encoder( | ||||
|         util.messages_to_onehot( | ||||
|             torch.arange(0, order), | ||||
|             order | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
| ax.scatter( | ||||
|     *zip(*encoded_vectors.tolist()), | ||||
|     zorder=10, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue