2019-12-13 20:17:57 +00:00
|
|
|
import constellation
|
|
|
|
from constellation import util
|
|
|
|
import torch
|
|
|
|
from matplotlib import pyplot
|
2019-12-15 04:31:47 +00:00
|
|
|
import matplotlib
|
2019-12-13 22:10:40 +00:00
|
|
|
from mpl_toolkits.axisartist.axislines import SubplotZero
|
2019-12-13 20:17:57 +00:00
|
|
|
|
|
|
|
# Number of learned symbols (constellation order); also the number of
# distinct messages the encoder maps to points
order = 4

# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)

# Color map used for decision regions
color_map = matplotlib.cm.Dark2

# Restore model from file. The layer sizes must match the ones used when the
# checkpoint was saved, otherwise `load_state_dict` (strict by default)
# raises on mismatched keys/shapes.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(4,),
    decoder_layers_sizes=(4,),
    channel_model=constellation.GaussianChannel()
)

# The model is never moved off the CPU in this script, so map the saved
# tensors to the CPU: a checkpoint written from a GPU session would otherwise
# fail to load on a machine without CUDA.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
model.eval()

# Extract encoding: run every message index 0..order-1 (converted by
# util.messages_to_onehot) through the encoder. No gradients are needed for
# plotting, so autograd is disabled.
with torch.no_grad():
    encoded_vectors = model.encoder(
        util.messages_to_onehot(
            torch.arange(0, order),
            order
        )
    )
|
|
|
|
|
2019-12-15 04:31:47 +00:00
|
|
|
# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Extend axes symmetrically around zero so that they fit the data, with a 5%
# margin. The largest absolute coordinate equals max(|min|, |max|); `.item()`
# turns the 0-d tensor into a plain Python float so that matplotlib and the
# later torch.arange-based grid receive an ordinary number.
axis_extent = encoded_vectors.abs().max().item() * 1.05
ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent)

# Hide the box borders (fully transparent RGBA color) but keep their ticks
for direction in ['left', 'bottom', 'right', 'top']:
    ax.axis[direction].line.set_color('#00000000')

# Show the zero-centered axes as arrows, without tick labels
for direction in ['xzero', 'yzero']:
    axis = ax.axis[direction]
    axis.set_visible(True)
    axis.set_axisline_style('-|>')
    axis.major_ticklabels.set_visible(False)

# Add axis names: 'I' to the right of the plot area, 'Q' above it
# (presumably in-phase / quadrature components of the constellation)
ax.annotate(
    'I', (1, 0.5), xycoords='axes fraction',
    xytext=(25, 0), textcoords='offset points',
    va='center', ha='right'
)

ax.annotate(
    'Q', (0.5, 1), xycoords='axes fraction',
    xytext=(0, 25), textcoords='offset points',
    va='center', ha='center'
)

ax.grid()
|
|
|
|
|
2019-12-15 04:31:47 +00:00
|
|
|
# Plot decision regions: classify every point of a dense square grid with the
# decoder and color it by the predicted message index.

# Boundaries 0, 1, ..., order give one bin per message index, so message k
# is drawn with color k of `color_map` (shared with the scatter plots).
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

# The grid covers twice the visible extent so the regions fill the axes.
regions_extent = 2 * axis_extent
step = 0.001 * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)

# torch.meshgrid uses 'ij' indexing here: the first output varies along rows,
# so rows follow y and columns follow x, matching `origin='lower'` below.
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)

# The grid has roughly 2000 x 2000 points; evaluate it without autograd so no
# graph is recorded for a backward pass that never happens.
with torch.no_grad():
    grid_images = model.decoder(
        torch.stack((grid_x, grid_y), dim=2)
    ).argmax(dim=2)

ax.imshow(
    grid_images,
    extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
    aspect='auto',
    origin='lower',
    cmap=color_map,
    norm=color_norm,
    alpha=0.1
)
|
|
|
|
|
2019-12-13 22:10:40 +00:00
|
|
|
# Plot encoded vectors: one large, black-edged marker per symbol, colored with
# the same index-to-color mapping as the decision regions and drawn on top of
# everything else.
symbol_xs, symbol_ys = zip(*encoded_vectors.tolist())
symbol_colors = range(len(encoded_vectors))
ax.scatter(
    symbol_xs, symbol_ys,
    s=60,
    c=symbol_colors,
    cmap=color_map,
    norm=color_norm,
    edgecolor='black',
    zorder=10,
)
|
2019-12-13 22:10:40 +00:00
|
|
|
|
2019-12-15 06:07:05 +00:00
|
|
|
# Plot noise: pass `noisy_count` copies of every encoded vector through the
# model's channel (presumably the GaussianChannel it was built with) and
# scatter the received points as small dots, colored by the sent message.
noisy_count = 1000
noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))

# `repeat(noisy_count, 1)` tiles the whole batch, so the sent-message indices
# cycle 0, 1, ..., order - 1, 0, 1, ... in exactly this order.
sent_labels = list(range(len(encoded_vectors))) * noisy_count
noisy_xs, noisy_ys = zip(*noisy_vectors.tolist())
ax.scatter(
    noisy_xs, noisy_ys,
    c=sent_labels,
    cmap=color_map,
    norm=color_norm,
    marker='.',
    s=5,
    alpha=0.7,
    zorder=8
)

pyplot.show()
|