"""Plot the constellation of a trained ``ConstellationNet``.

Restores the model weights saved for the configured ``order``, extracts the
constellation from the model and hands it, together with the model's channel
and decoder, to ``util.plot_constellation`` for display in a matplotlib
window.
"""
import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero

# Number of learned symbols
order = 16

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)

# Color map used for decision regions
# NOTE(review): nothing below reads this value; confirm whether
# util.plot_constellation is supposed to receive it.
color_map = matplotlib.cm.Dark2

# Rebuild the network before restoring its weights. load_state_dict is strict
# by default, so these layer sizes presumably have to match the ones used
# when the checkpoint was written -- verify against the training script.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8, 4),
    decoder_layers_sizes=(4, 8),
    channel_model=constellation.GaussianChannel()
)

# map_location='cpu' lets a checkpoint that was saved from a CUDA device be
# restored on a machine without a GPU; the model above lives on the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
state_dict = torch.load(input_file, map_location='cpu')
model.load_state_dict(state_dict)

# Switch the module to evaluation mode before using it for plotting
model.eval()

# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Use a dedicated name for the extracted constellation: reusing
# `constellation` would rebind (and hide) the imported module of that name.
learned_constellation = model.get_constellation()

# NOTE(review): grid_step and noise_samples are forwarded unchanged; their
# exact meaning is defined by util.plot_constellation (not visible here).
util.plot_constellation(
    ax,
    learned_constellation,
    model.channel,
    model.decoder,
    grid_step=0.001,
    noise_samples=0
)

pyplot.show()