"""Plot the constellation learned by a trained ``ConstellationNet``.

Restores the model saved for a given constellation order, runs every
message through the model's encoder, and draws the resulting vectors on a
diagram whose zero-centered axes are labelled I and Q.
"""

import constellation
from constellation import util
import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero

# Number of learned symbols
order = 4

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)


def _nonzero_ticks_within(ticks, limit):
    """Return the ticks that are non-zero and lie inside ``[-limit, limit]``.

    Ticks beyond ``limit`` are dropped because ``Axes.set_xticks`` /
    ``set_yticks`` widen the view limits to cover every tick they are given,
    which would undo the symmetric limits derived from the data.
    """
    return [tick for tick in ticks if tick != 0 and -limit <= tick <= limit]


# Restore model from file
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel()
)

# map_location='cpu' lets a checkpoint saved from a GPU run be loaded on a
# CPU-only machine; nothing in this script moves the model off the CPU.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
model.eval()

# Compute encoded vectors, one row per message.
# NOTE(review): presumably shape (order, 2) -- the I and Q coordinates that
# are scattered below; confirm against ConstellationNet's encoder.
with torch.no_grad():
    encoded_vectors = model.encoder(
        util.messages_to_onehot(
            torch.arange(0, order),
            order
        )
    )

fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Extend axes symmetrically around zero so that they fit the data with a 5%
# margin. Fall back to a unit extent when every coordinate is zero so that
# the axis limits are not degenerate.
extent = encoded_vectors.abs().max().item() * 1.05
if extent == 0:
    extent = 1.0
ax.set_xlim(-extent, extent)
ax.set_ylim(-extent, extent)

# Hide borders
for direction in ['left', 'bottom', 'right', 'top']:
    ax.axis[direction].set_visible(False)

# Show zero-centered axes with arrow heads
for direction in ['xzero', 'yzero']:
    axis = ax.axis[direction]
    axis.set_visible(True)
    axis.set_axisline_style("-|>")

# Configure axes ticks and labels
ax.annotate(
    'I', (1, 0.5), xycoords='axes fraction',
    xytext=(25, 0), textcoords='offset points',
    va='center', ha='right'
)

ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
ax.axis['xzero'].major_ticklabels.set_ha('center')
ax.axis['xzero'].major_ticklabels.set_va('top')

ax.annotate(
    'Q', (0.5, 1), xycoords='axes fraction',
    xytext=(0, 25), textcoords='offset points',
    va='center', ha='center'
)

ax.axis['yzero'].major_ticklabels.set_rotation(-90)
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
ax.axis['yzero'].major_ticklabels.set_ha('left')
ax.axis['yzero'].major_ticklabels.set_va('center')

# Remove the 0 tick (and any tick outside the visible range) from both axes
# so the origin is not labelled twice, then place a single "0" next to it.
ax.set_xticks(_nonzero_ticks_within(ax.get_xticks(), extent))
ax.set_yticks(_nonzero_ticks_within(ax.get_yticks(), extent))

ax.annotate(
    '0', (0, 0),
    xytext=(15, -10), textcoords='offset points',
    va='center', ha='center'
)

ax.grid()

# Plot encoded vectors
points = encoded_vectors.tolist()
ax.scatter(*zip(*points), zorder=10)

# Label each vector with its 1-based message index
for index, point in enumerate(points, start=1):
    ax.annotate(
        str(index), tuple(point),
        xytext=(5, 5), textcoords='offset points',
        backgroundcolor='white', zorder=9
    )

pyplot.show()