"""Plot the constellation learned by a trained ConstellationNet.

Restores the model weights from ``output/constellation-order-{order}.pth``
and draws the learned symbols and decision regions on a single set of axes
using ``util.plot_constellation``.
"""

import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero

# Number of learned symbols
order = 16

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)

# Color map used for decision regions
# NOTE(review): this is not passed to util.plot_constellation below, so it
# currently has no effect -- confirm whether the helper should receive it.
color_map = matplotlib.cm.Dark2

# Restore model from file. The layer sizes must match the architecture the
# checkpoint was trained with: load_state_dict is strict by default and
# raises on any missing, unexpected or mis-shaped parameter.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8, 4),
    decoder_layers_sizes=(4, 8),
    channel_model=constellation.GaussianChannel()
)

# map_location='cpu' lets a checkpoint saved from a GPU run be restored on a
# CPU-only machine; load_state_dict then copies the tensors into the model's
# own parameters wherever they live.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
model.eval()

# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Inference only: disable autograd so the model evaluations done while
# plotting (the helper presumably runs the decoder over a fine grid,
# grid_step=0.001) do not record a computation graph.
with torch.no_grad():
    # Named so as not to shadow the ``constellation`` module imported above.
    learned_constellation = model.get_constellation()

    util.plot_constellation(
        ax, learned_constellation,
        model.channel, model.decoder,
        grid_step=0.001, noise_samples=5000
    )

pyplot.show()