122 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			122 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
| import constellation
 | |
| from constellation import util
 | |
| import torch
 | |
| from matplotlib import pyplot
 | |
| import matplotlib
 | |
| from mpl_toolkits.axisartist.axislines import SubplotZero
 | |
| 
 | |
# How many distinct symbols (messages) the constellation was trained on.
order = 4

# Checkpoint produced by training for this particular order.
input_file = f'output/constellation-order-{order}.pth'

# Colormap shared by the decision regions, the constellation points and the
# noise cloud, so each message keeps the same color everywhere.
color_map = matplotlib.cm.Dark2

# Rebuild the network with the same architecture used for training, then
# restore its trained weights. The layer sizes must match the checkpoint,
# otherwise load_state_dict() (strict by default) raises.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel()
)

# map_location='cpu' lets a checkpoint that was saved from GPU tensors load
# on a CPU-only machine; this script only runs inference for plotting.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
model.eval()

# Encode every message 0 .. order-1 to obtain the learned constellation
# points (messages are presumably turned into one-hot vectors by
# util.messages_to_onehot). Gradients are not needed for plotting.
with torch.no_grad():
    encoded_vectors = model.encoder(
        util.messages_to_onehot(
            torch.arange(0, order),
            order
        )
    )

# Setup plot: a single axisartist subplot, which supports drawing axes
# through the origin (configured below).
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Extend axes symmetrically around zero so that they fit data, with a 5%
# margin. The largest absolute coordinate equals max(|min|, |max|); .item()
# turns the 0-dim tensor into a plain Python float so matplotlib and
# torch.arange receive an ordinary scalar rather than a tensor.
axis_extent = encoded_vectors.abs().max().item() * 1.05
ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent)

# Make the surrounding frame invisible (fully transparent line color) while
# leaving its ticks in place.
for side in ('left', 'bottom', 'right', 'top'):
    ax.axis[side].line.set_color('#00000000')

# Draw the axes that pass through the origin as arrows, without tick labels.
for zero_axis_name in ('xzero', 'yzero'):
    zero_axis = ax.axis[zero_axis_name]
    zero_axis.set_visible(True)
    zero_axis.set_axisline_style('-|>')
    zero_axis.major_ticklabels.set_visible(False)

# Name the horizontal axis 'I' and the vertical axis 'Q', placing each label
# 25 points beyond the right / top edge of the plot area.
label_placement = dict(xycoords='axes fraction', textcoords='offset points')
ax.annotate(
    'I', (1, 0.5), xytext=(25, 0),
    va='center', ha='right', **label_placement
)
ax.annotate(
    'Q', (0.5, 1), xytext=(0, 25),
    va='center', ha='center', **label_placement
)

ax.grid()

# Plot decision regions: evaluate the decoder on a dense grid of received
# points and color each pixel by the message the decoder picks there.
# BoundaryNorm with boundaries 0..order and `order` colors maps integer
# message k to color index k.
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

# Cover twice the visible extent so regions stay filled when the view is
# panned or zoomed out; 0.001 * extent gives 2000 samples per axis.
regions_extent = 2 * axis_extent
step = 0.001 * regions_extent

# Sample at pixel centers (offset by half a step) so that the samples line
# up with the `extent` passed to imshow below, which describes pixel edges.
# Starting half a step in also keeps float rounding from adding a spurious
# extra row/column to the grid.
grid_range = torch.arange(-regions_extent + step / 2, regions_extent, step)
# Default ('ij') indexing: rows vary with y and columns with x, matching
# imshow(origin='lower').
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)

# Pure inference: disable autograd so no graph is recorded for the
# ~2000 x 2000 grid points. argmax over the last dimension yields the
# decoded message index for every grid point.
with torch.no_grad():
    grid_images = model.decoder(
        torch.stack((grid_x, grid_y), dim=2)
    ).argmax(dim=2)

ax.imshow(
    grid_images,
    extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
    aspect='auto',
    origin='lower',
    cmap=color_map,
    norm=color_norm,
    alpha=0.1
)

# Overlay the learned constellation points, one color per message, drawn on
# top of everything else (zorder 10 is above the noise cloud's 8).
symbol_colors = range(len(encoded_vectors))
ax.scatter(
    *encoded_vectors.t().tolist(),
    c=symbol_colors,
    cmap=color_map,
    norm=color_norm,
    s=60,
    edgecolor='black',
    zorder=10,
)

# Plot noise: pass every constellation point through the model's channel
# `noisy_count` times and scatter the received samples.
noisy_count = 1000
transmitted = encoded_vectors.repeat(noisy_count, 1)
noisy_vectors = model.channel(transmitted)

# repeat(noisy_count, 1) tiles the whole block of points, so received
# sample i came from message i % len(encoded_vectors); tile the color
# labels the same way.
noise_labels = list(range(len(encoded_vectors))) * noisy_count
ax.scatter(
    *noisy_vectors.t().tolist(),
    c=noise_labels,
    cmap=color_map,
    norm=color_norm,
    marker='.',
    s=5,
    alpha=0.7,
    zorder=8,
)

pyplot.show()