Plot decision regions from the decoder
This commit is contained in:
		
							parent
							
								
									3b40e27070
								
							
						
					
					
						commit
						8f6363ee21
					
				
							
								
								
									
										77
									
								
								plot.py
								
								
								
								
							
							
						
						
									
										77
									
								
								plot.py
								
								
								
								
							|  | @ -2,6 +2,7 @@ import constellation | |||
| from constellation import util | ||||
| import torch | ||||
| from matplotlib import pyplot | ||||
| import matplotlib | ||||
| from mpl_toolkits.axisartist.axislines import SubplotZero | ||||
| 
 | ||||
| # Number learned symbols | ||||
|  | @ -10,36 +11,16 @@ order = 4 | |||
| # File in which the trained model is saved | ||||
| input_file = 'output/constellation-order-{}.pth'.format(order) | ||||
| 
 | ||||
| # Restore model from file | ||||
| model = constellation.ConstellationNet( | ||||
|     order=order, | ||||
|     encoder_layers_sizes=(4,), | ||||
|     decoder_layers_sizes=(4,), | ||||
|     channel_model=constellation.GaussianChannel() | ||||
| ) | ||||
| 
 | ||||
| model.load_state_dict(torch.load(input_file)) | ||||
| model.eval() | ||||
| 
 | ||||
| # Compute encoded vectors | ||||
| with torch.no_grad(): | ||||
|     encoded_vectors = model.encoder( | ||||
|         util.messages_to_onehot( | ||||
|             torch.arange(0, order), | ||||
|             order | ||||
|         ) | ||||
|     ) | ||||
| # Color map used for decision regions | ||||
| color_map = matplotlib.cm.Set1 | ||||
| 
 | ||||
| # Setup plot | ||||
| fig = pyplot.figure() | ||||
| ax = SubplotZero(fig, 111) | ||||
| fig.add_subplot(ax) | ||||
| 
 | ||||
| # Extend axes symmetrically around zero so that they fit data | ||||
| extent = max( | ||||
|     abs(encoded_vectors.min()), | ||||
|     abs(encoded_vectors.max()) | ||||
| ) * 1.05 | ||||
| 
 | ||||
| extent = 1.5 | ||||
| ax.set_xlim(-extent, extent) | ||||
| ax.set_ylim(-extent, extent) | ||||
| 
 | ||||
|  | @ -87,10 +68,52 @@ ax.annotate( | |||
| 
 | ||||
| ax.grid() | ||||
| 
 | ||||
| # Plot encoded vectors | ||||
| ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10) | ||||
| # Restore model from file | ||||
| model = constellation.ConstellationNet( | ||||
|     order=order, | ||||
|     encoder_layers_sizes=(4,), | ||||
|     decoder_layers_sizes=(4,), | ||||
|     channel_model=constellation.GaussianChannel() | ||||
| ) | ||||
| 
 | ||||
| model.load_state_dict(torch.load(input_file)) | ||||
| model.eval() | ||||
| 
 | ||||
| # Plot decision regions | ||||
| color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) | ||||
| 
 | ||||
| step = 0.01 | ||||
| grid_range = torch.arange(-extent, extent, step) | ||||
| grid_y, grid_x = torch.meshgrid(grid_range, grid_range) | ||||
| grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) | ||||
| 
 | ||||
| ax.imshow( | ||||
|     grid_images, extent=(-extent, extent, -extent, extent), | ||||
|     aspect="auto", | ||||
|     origin="lower", | ||||
|     cmap=color_map, | ||||
|     norm=color_norm, | ||||
|     alpha=0.15 | ||||
| ) | ||||
| 
 | ||||
| # Plot encoded vectors | ||||
| with torch.no_grad(): | ||||
|     encoded_vectors = model.encoder( | ||||
|         util.messages_to_onehot( | ||||
|             torch.arange(0, order), | ||||
|             order | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
| ax.scatter( | ||||
|     *zip(*encoded_vectors.tolist()), | ||||
|     zorder=10, | ||||
|     c=range(len(encoded_vectors)), | ||||
|     edgecolor='black', | ||||
|     cmap=color_map, | ||||
|     norm=color_norm, | ||||
| ) | ||||
| 
 | ||||
| # Add index label for each vector | ||||
| for row in range(order): | ||||
|     ax.annotate( | ||||
|         row + 1, encoded_vectors[row], | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue