# constellationnet/plot.py
# (earliest revision shown in the source view: 2019-12-13 20:17:57 +00:00)
"""Plot the learned constellation and decision regions of a trained model.

Restores the model saved for a constellation of ``order`` symbols and draws it
with ``util.plot_constellation`` on a ``SubplotZero`` axis, then opens an
interactive matplotlib window.
"""

import inspect

import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero

# Number learned symbols
order = 16

# Color map used for decision regions and points
# NOTE(review): this is not passed to util.plot_constellation below; confirm
# whether that helper accepts a colormap argument or whether this is unused.
color_map = matplotlib.cm.Dark2

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)


def _load_model(path):
    """Restore a whole pickled model object saved with ``torch.save(model)``.

    :param path: Path of the checkpoint file.
    :returns: The unpickled model, with tensors mapped onto the CPU.
    :raises FileNotFoundError: If *path* does not exist (raised by torch.load).
    """
    # map_location='cpu' lets a checkpoint saved from a GPU run be restored
    # on a machine without CUDA.
    load_kwargs = {'map_location': 'cpu'}
    # The checkpoint holds a full module object, not just a state_dict.
    # PyTorch >= 2.6 defaults to weights_only=True and would refuse to
    # unpickle it, while releases before 1.13 do not accept the keyword.
    # SECURITY: weights_only=False runs arbitrary pickle code, so only load
    # checkpoints you produced yourself.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


def main():
    """Restore the trained model and display its constellation plot."""
    # Restore model from file and switch it to inference mode
    model = _load_model(input_file)
    model.eval()

    # Setup plot
    fig = pyplot.figure()
    ax = SubplotZero(fig, 111)
    fig.add_subplot(ax)

    # Plotting needs no gradients. Disabling autograd avoids recording a
    # graph for the decoder evaluations done inside the helper (presumably
    # many, given grid_step=0.001).
    with torch.no_grad():
        # Named ``points`` so the imported ``constellation`` module is not
        # shadowed.
        points = model.get_constellation()
        util.plot_constellation(
            ax, points,
            model.channel, model.decoder,
            grid_step=0.001, noise_samples=0
        )

    pyplot.show()


if __name__ == '__main__':
    main()