From 8f6363ee219b04872e64182dad4b273d7ba5a701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Sat, 14 Dec 2019 23:31:47 -0500 Subject: [PATCH] Plot decision regions from the decoder --- plot.py | 77 +++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/plot.py b/plot.py index 6e495f2..2442262 100644 --- a/plot.py +++ b/plot.py @@ -2,6 +2,7 @@ import constellation from constellation import util import torch from matplotlib import pyplot +import matplotlib from mpl_toolkits.axisartist.axislines import SubplotZero # Number learned symbols @@ -10,36 +11,16 @@ order = 4 # File in which the trained model is saved input_file = 'output/constellation-order-{}.pth'.format(order) -# Restore model from file -model = constellation.ConstellationNet( - order=order, - encoder_layers_sizes=(4,), - decoder_layers_sizes=(4,), - channel_model=constellation.GaussianChannel() -) - -model.load_state_dict(torch.load(input_file)) -model.eval() - -# Compute encoded vectors -with torch.no_grad(): - encoded_vectors = model.encoder( - util.messages_to_onehot( - torch.arange(0, order), - order - ) - ) +# Color map used for decision regions +color_map = matplotlib.cm.Set1 +# Setup plot fig = pyplot.figure() ax = SubplotZero(fig, 111) fig.add_subplot(ax) # Extend axes symmetrically around zero so that they fit data -extent = max( - abs(encoded_vectors.min()), - abs(encoded_vectors.max()) -) * 1.05 - +extent = 1.5 ax.set_xlim(-extent, extent) ax.set_ylim(-extent, extent) @@ -87,10 +68,52 @@ ax.annotate( ax.grid() -# Plot encoded vectors -ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10) +# Restore model from file +model = constellation.ConstellationNet( + order=order, + encoder_layers_sizes=(4,), + decoder_layers_sizes=(4,), + channel_model=constellation.GaussianChannel() +) + +model.load_state_dict(torch.load(input_file)) +model.eval() + +# Plot decision regions +color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) + +step = 0.01 +grid_range = torch.arange(-extent, extent, step) +grid_y, grid_x = torch.meshgrid(grid_range, grid_range) +grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) + +ax.imshow( + grid_images, extent=(-extent, extent, -extent, extent), + aspect="auto", + origin="lower", + cmap=color_map, + norm=color_norm, + alpha=0.15 +) + +# Plot encoded vectors +with torch.no_grad(): + encoded_vectors = model.encoder( + util.messages_to_onehot( + torch.arange(0, order), + order + ) + ) + +ax.scatter( + *zip(*encoded_vectors.tolist()), + zorder=10, + c=range(len(encoded_vectors)), + edgecolor='black', + cmap=color_map, + norm=color_norm, +) -# Add index label for each vector for row in range(order): ax.annotate( row + 1, encoded_vectors[row],