Plot decision regions from the decoder

This commit is contained in:
Mattéo Delabre 2019-12-14 23:31:47 -05:00
parent 3b40e27070
commit 8f6363ee21
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 50 additions and 27 deletions

77
plot.py
View File

@ -2,6 +2,7 @@ import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero
# Number learned symbols
@ -10,36 +11,16 @@ order = 4
# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model.eval()
# Compute encoded vectors
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
# Color map used for decision regions
color_map = matplotlib.cm.Set1
# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
# Extend axes symmetrically around zero so that they fit data
extent = max(
abs(encoded_vectors.min()),
abs(encoded_vectors.max())
) * 1.05
extent = 1.5
ax.set_xlim(-extent, extent)
ax.set_ylim(-extent, extent)
@ -87,10 +68,52 @@ ax.annotate(
ax.grid()
# Plot encoded vectors
ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10)
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model.eval()
# Plot decision regions
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
step = 0.01
grid_range = torch.arange(-extent, extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
ax.imshow(
grid_images, extent=(-extent, extent, -extent, extent),
aspect="auto",
origin="lower",
cmap=color_map,
norm=color_norm,
alpha=0.15
)
# Plot encoded vectors
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
ax.scatter(
*zip(*encoded_vectors.tolist()),
zorder=10,
c=range(len(encoded_vectors)),
edgecolor='black',
cmap=color_map,
norm=color_norm,
)
# Add index label for each vector
for row in range(order):
ax.annotate(
row + 1, encoded_vectors[row],