constellationnet/plot.py

122 lines
2.8 KiB
Python

import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero
# Number of symbols in the learned constellation
order = 4
# Checkpoint file holding the trained model's parameters
input_file = f'output/constellation-order-{order}.pth'
# Colormap shared by the decision regions, the constellation points and the
# noise samples (matplotlib's qualitative Dark2 palette)
color_map = matplotlib.cm.Dark2
# Restore model from file. The layer sizes given here must match the ones used
# when the checkpoint was saved; otherwise load_state_dict (strict by default)
# raises on mismatched keys or shapes.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel(),
)
# map_location='cpu' lets a checkpoint saved from GPU tensors be deserialized on
# a CPU-only machine; load_state_dict then copies the values into the model.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
# Evaluation mode: only matters if the network contains train/eval-dependent
# layers (e.g. dropout, batch norm) — cannot tell from here.
model.eval()
# Extract encoding: feed every message 0 .. order-1, converted with
# util.messages_to_onehot, through the encoder to get the learned constellation
# points. Gradients are not needed for plotting.
with torch.no_grad():
    all_messages = torch.arange(0, order)
    onehot_messages = util.messages_to_onehot(all_messages, order)
    encoded_vectors = model.encoder(onehot_messages)
# Setup plot: a single subplot whose axis lines can be drawn through the origin
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
# Extend axes symmetrically around zero so that every constellation point fits,
# with a 5% margin. The largest absolute coordinate equals
# max(|min|, |max|); .item() turns it into the plain Python float matplotlib's
# limit setters expect (instead of a 0-d tensor).
axis_extent = encoded_vectors.abs().max().item() * 1.05
ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent)
# Make the surrounding frame invisible (fully transparent line color) while
# leaving its ticks in place
for side in ('left', 'bottom', 'right', 'top'):
    ax.axis[side].line.set_color('#00000000')
# Draw arrow-tipped axes through the origin, with their tick labels hidden
for zero_axis in (ax.axis['xzero'], ax.axis['yzero']):
    zero_axis.set_visible(True)
    zero_axis.set_axisline_style('-|>')
    zero_axis.major_ticklabels.set_visible(False)
# Add axis names: "I" at the middle of the right edge and "Q" at the middle of
# the top edge, each pushed 25 points outside the axes box.
for axis_label, label_anchor, label_offset, label_ha in (
    ('I', (1, 0.5), (25, 0), 'right'),
    ('Q', (0.5, 1), (0, 25), 'center'),
):
    ax.annotate(
        axis_label,
        label_anchor,
        xycoords='axes fraction',
        xytext=label_offset,
        textcoords='offset points',
        va='center',
        ha=label_ha,
    )
ax.grid()
# Plot decision regions: classify every point of a dense grid with the decoder
# and paint it with the color of the symbol it is decoded as.
# Boundaries 0, 1, ..., order with `order` colors give one discrete color per
# integer symbol label.
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
# The grid spans twice the visible half-width (axis_extent), presumably so the
# regions stay filled when panning or zooming out. float() accepts either a
# Python float or a 0-d tensor for axis_extent.
regions_extent = 2 * float(axis_extent)
step = 0.001 * regions_extent  # -> about 2000 samples per side
grid_range = torch.arange(-regions_extent, regions_extent, step)
# torch.meshgrid's default 'ij' indexing makes grid_y vary along rows and grid_x
# along columns, matching imshow's row=y / column=x layout with origin='lower'.
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
# Inference only: without no_grad the decoder would record an autograd graph
# over the ~4 million grid points, wasting a large amount of memory.
# Assumes the decoder maps (..., 2) points to (..., order) scores — TODO confirm.
with torch.no_grad():
    grid_images = model.decoder(
        torch.stack((grid_x, grid_y), dim=2)
    ).argmax(dim=2)
ax.imshow(
    grid_images.numpy(),
    extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
    aspect='auto',
    origin='lower',
    cmap=color_map,
    norm=color_norm,
    alpha=0.1,
)
# Plot the learned constellation points, each colored by its symbol index and
# outlined in black; zorder=10 draws them above the decision regions and noise.
symbol_i, symbol_q = zip(*encoded_vectors.tolist())
ax.scatter(
    symbol_i,
    symbol_q,
    c=list(range(len(encoded_vectors))),
    s=60,
    edgecolor='black',
    cmap=color_map,
    norm=color_norm,
    zorder=10,
)
# Plot noise: send many copies of every constellation point through the channel
# and show where the received samples land, colored by the transmitted symbol.
noisy_count = 1000
symbol_count = len(encoded_vectors)
transmitted = encoded_vectors.repeat(noisy_count, 1)
noisy_vectors = model.channel(transmitted)
# repeat() tiles the whole batch, so sample k was sent as symbol k % symbol_count
noise_colors = [
    sample % symbol_count for sample in range(symbol_count * noisy_count)
]
noise_i, noise_q = zip(*noisy_vectors.tolist())
ax.scatter(
    noise_i,
    noise_q,
    marker='.',
    s=5,
    c=noise_colors,
    cmap=color_map,
    norm=color_norm,
    alpha=0.7,
    zorder=8,
)
pyplot.show()