102 lines
2.5 KiB
Python
102 lines
2.5 KiB
Python
import constellation
|
|
from constellation import util
|
|
import torch
|
|
from matplotlib import pyplot
|
|
from mpl_toolkits.axisartist.axislines import SubplotZero
|
|
|
|
# Number of distinct symbols (messages) in the learned constellation.
order = 4

# Checkpoint file from which the trained model's parameters are restored.
input_file = f'output/constellation-order-{order}.pth'

# Rebuild the network and restore its learned parameters from `input_file`.
# The layer sizes below must match the saved checkpoint, otherwise
# load_state_dict (strict by default) raises on mismatched keys/shapes.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel(),
)

# map_location='cpu' lets a checkpoint saved from a CUDA run be loaded on a
# machine without a GPU; this script only needs the CPU to plot.
model.load_state_dict(torch.load(input_file, map_location='cpu'))

# Put the module in evaluation mode (training flag off) before inference.
model.eval()

# Run every message index 0..order-1 (converted by util.messages_to_onehot)
# through the encoder to obtain the learned constellation points. Gradients
# are not needed for plotting, so autograd is disabled.
all_messages = torch.arange(order)
with torch.no_grad():
    encoded_vectors = model.encoder(util.messages_to_onehot(all_messages, order))

# SubplotZero (axisartist) exposes extra axis lines through the origin,
# reachable as ax.axis['xzero'] / ax.axis['yzero'].
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Extend the axes symmetrically around zero so that every point fits, with a
# 5% margin. The largest absolute coordinate equals max(|min|, |max|);
# `.item()` turns the 0-d tensor into a plain float so matplotlib receives an
# ordinary number rather than a torch tensor.
extent = encoded_vectors.abs().max().item() * 1.05

# If every point sits exactly at the origin the limits would be identical,
# which matplotlib warns about and expands arbitrarily; use a unit extent.
if extent == 0:
    extent = 1.0

ax.set_xlim(-extent, extent)
ax.set_ylim(-extent, extent)

# Replace the rectangular frame with arrow-tipped axis lines that cross at
# the origin: hide the four borders, then show and style the zero axes.
for side in ('left', 'bottom', 'right', 'top'):
    ax.axis[side].set_visible(False)

for zero_axis in (ax.axis['xzero'], ax.axis['yzero']):
    zero_axis.set_visible(True)
    zero_axis.set_axisline_style("-|>")

# Label the horizontal axis 'I' (to the right of the plot area at mid-height,
# where the symmetric limits put y = 0) and the vertical axis 'Q' (above the
# plot area at mid-width). Tick labels get a white background so they stay
# readable over grid lines.
x_tick_labels = ax.axis['xzero'].major_ticklabels
y_tick_labels = ax.axis['yzero'].major_ticklabels

ax.annotate('I', (1, 0.5), xycoords='axes fraction',
            xytext=(25, 0), textcoords='offset points',
            va='center', ha='right')
x_tick_labels.set_backgroundcolor('white')
x_tick_labels.set_ha('center')
x_tick_labels.set_va('top')

ax.annotate('Q', (0.5, 1), xycoords='axes fraction',
            xytext=(0, 25), textcoords='offset points',
            va='center', ha='center')
# Vertical-axis tick labels are rotated to read along the axis line.
y_tick_labels.set_rotation(-90)
y_tick_labels.set_backgroundcolor('white')
y_tick_labels.set_ha('left')
y_tick_labels.set_va('center')

# Replace the two zero ticks (one per axis, which would overlap at the
# origin) with a single '0' label placed just below-right of the origin.
#
# get_xticks()/get_yticks() are not clipped to the view limits, and
# set_xticks()/set_yticks() expand the limits to include every tick they are
# given. So also drop ticks outside the current limits; otherwise the
# symmetric, data-fitted limits would be silently widened.
x_low, x_high = sorted(ax.get_xlim())
x_ticks = ax.get_xticks()
ax.set_xticks(x_ticks[(x_ticks != 0) & (x_ticks >= x_low) & (x_ticks <= x_high)])

y_low, y_high = sorted(ax.get_ylim())
y_ticks = ax.get_yticks()
ax.set_yticks(y_ticks[(y_ticks != 0) & (y_ticks >= y_low) & (y_ticks <= y_high)])

ax.annotate(
    '0', (0, 0),
    xytext=(15, -10), textcoords='offset points',
    va='center', ha='center'
)

ax.grid()

# Plot the encoded vectors as plain Python floats; zorder 10 keeps the
# markers above the grid and the labels below.
points = encoded_vectors.tolist()
ax.scatter(*zip(*points), zorder=10)

# Number each point 1..N in row order, offset up-right of its marker. The
# label's white background sits at zorder 9, just beneath the markers.
for number, (x, y) in enumerate(points, start=1):
    ax.annotate(
        str(number), (x, y),
        xytext=(5, 5), textcoords='offset points',
        backgroundcolor='white', zorder=9,
    )

pyplot.show()