"""Plot the constellation learned by a trained ConstellationNet.

Restores a saved model, draws its encoded symbols on zero-centered I/Q
axes, shades the decoder's decision regions over a grid around them, and
overlays samples obtained by passing the symbols through the model's
channel.
"""
import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero

# Number of learned symbols
order = 4

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)

# Color map used for decision regions
color_map = matplotlib.cm.Dark2

# Restore model from file
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel()
)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine; the model constructed above lives on the CPU anyway.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
model.eval()

# Extract encoding: one encoded vector per message 0, 1, ..., order - 1
with torch.no_grad():
    encoded_vectors = model.encoder(
        util.messages_to_onehot(
            torch.arange(0, order),
            order
        )
    )

# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Extend axes symmetrically around zero so that they fit data.
# max(|min|, |max|) equals max(|x|); .item() yields a plain Python float
# (instead of a 0-d tensor) for matplotlib and torch.arange below.
axis_extent = encoded_vectors.abs().max().item() * 1.05
ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent)

# Hide borders but keep ticks
for direction in ['left', 'bottom', 'right', 'top']:
    ax.axis[direction].line.set_color('#00000000')

# Show zero-centered axes without tick labels
for direction in ['xzero', 'yzero']:
    axis = ax.axis[direction]
    axis.set_visible(True)
    axis.set_axisline_style('-|>')
    axis.major_ticklabels.set_visible(False)

# Add axis names
ax.annotate(
    'I', (1, 0.5), xycoords='axes fraction',
    xytext=(25, 0), textcoords='offset points',
    va='center', ha='right'
)

ax.annotate(
    'Q', (0.5, 1), xycoords='axes fraction',
    xytext=(0, 25), textcoords='offset points',
    va='center', ha='center'
)

ax.grid()

# Plot decision regions.
# Boundaries 0, 1, ..., order put each integer label i into its own bin i,
# so decoded label i gets the same color as encoded symbol i below.
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
regions_extent = 2 * axis_extent
step = 0.001 * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)

# The grid has about 2000 points per side; evaluate the decoder without
# autograd so no graph/activations are kept for gradients never used.
with torch.no_grad():
    grid_images = model.decoder(
        torch.stack((grid_x, grid_y), dim=2)
    ).argmax(dim=2)

ax.imshow(
    grid_images,
    extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
    aspect='auto',
    origin='lower',
    cmap=color_map,
    norm=color_norm,
    alpha=0.1
)

# Plot encoded vectors
ax.scatter(
    *zip(*encoded_vectors.tolist()),
    zorder=10,
    s=60,
    c=range(len(encoded_vectors)),
    edgecolor='black',
    cmap=color_map,
    norm=color_norm,
)

# Plot noise. repeat(noisy_count, 1) tiles the whole set of symbols
# noisy_count times, so sample k comes from symbol k % len(encoded_vectors),
# which is exactly the order of the color list passed to scatter below.
noisy_count = 1000
with torch.no_grad():
    noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))

ax.scatter(
    *zip(*noisy_vectors.tolist()),
    marker='.',
    s=5,
    c=list(range(len(encoded_vectors))) * noisy_count,
    cmap=color_map,
    norm=color_norm,
    alpha=0.7,
    zorder=8
)

pyplot.show()