# Python · 41 lines · 913 B (source-viewer metadata)
"""Plot the decision regions of a trained constellation model.

Restores a ``constellation.ConstellationNet`` of the configured ``order``
from ``output/constellation-order-{order}.pth``, draws its learned
constellation together with the decoder's decision regions via
``constellation.util.plot_constellation``, and shows the figure.
"""

import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero

# Number of learned symbols
order = 16

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)

# Color map used for decision regions
# NOTE(review): not referenced anywhere below — confirm whether it should be
# passed to util.plot_constellation or can be dropped.
color_map = matplotlib.cm.Dark2

# Restore model from file. The architecture arguments must match the ones
# used when the checkpoint was saved, otherwise load_state_dict (strict by
# default) raises on missing/mismatched parameters.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8, 4),
    decoder_layers_sizes=(4, 8),
    channel_model=constellation.GaussianChannel(),
)

# map_location='cpu' allows a checkpoint containing CUDA tensors to be loaded
# on a CPU-only machine; the model above is constructed on the CPU anyway.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
model.eval()

# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Plotting only needs forward passes. Disabling autograd avoids recording a
# graph for the decoder evaluations (presumably over a fine grid, given
# grid_step=0.001 — confirm in util.plot_constellation).
with torch.no_grad():
    # Named `points` rather than `constellation` so the imported
    # `constellation` module is not shadowed.
    points = model.get_constellation()
    util.plot_constellation(
        ax, points,
        model.channel, model.decoder,
        grid_step=0.001, noise_samples=5000,
    )

pyplot.show()