constellationnet/plot.py

41 lines
908 B
Python

import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero
# Number of symbols in the learned constellation
order = 4

# Checkpoint file holding the trained model; its name encodes the order
input_file = f'output/constellation-order-{order}.pth'
# Color map used for decision regions (matplotlib's built-in Dark2 map).
# NOTE(review): not referenced again in the visible part of this script —
# confirm whether it is meant to be passed to util.plot_constellation.
color_map = matplotlib.cm.Dark2
# Instantiate the network and restore its trained weights from the checkpoint.
# NOTE(review): the layer sizes below presumably have to match the ones the
# checkpoint was trained with — confirm against the training script.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)

# map_location='cpu' allows a checkpoint that was saved from a GPU to be
# deserialized on a machine without CUDA; load_state_dict then copies the
# stored values into the model's own parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(input_file, map_location='cpu'))

# Switch the model to evaluation mode before it is used for plotting
model.eval()
# Setup plot: one SubplotZero axes attached to a fresh figure
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Keep the learned points under their own name: rebinding `constellation`
# here would shadow the imported `constellation` module for the rest of
# the script.
constellation_points = model.get_constellation()

# Draw the constellation using the model's channel and decoder
util.plot_constellation(
    ax, constellation_points,
    model.channel, model.decoder,
    grid_step=0.001, noise_samples=2500
)

# Display the figure (blocks until the window is closed in interactive use)
pyplot.show()