From a01d83f33996d2d976492772465ee03570f780c6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Matt=C3=A9o=20Delabre?=
Date: Fri, 13 Dec 2019 17:10:40 -0500
Subject: [PATCH] Improve plot

---
 plot.py  | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 train.py | 27 ++++++++++----------
 2 files changed, 87 insertions(+), 17 deletions(-)

diff --git a/plot.py b/plot.py
index 7c0bd9b..65d9041 100644
--- a/plot.py
+++ b/plot.py
@@ -2,12 +2,15 @@ import constellation
 from constellation import util
 import torch
 from matplotlib import pyplot
+from mpl_toolkits.axisartist.axislines import SubplotZero
+import math
+import numpy
 
 # Number learned symbols
 order = 4
 
 # File in which the trained model is saved
-input_file = 'output/constellation-net.tc'
+input_file = 'output/constellation-order-{}.tc'.format(order)
 
 model = constellation.ConstellationNet(order=order)
 model.load_state_dict(torch.load(input_file))
@@ -19,8 +22,74 @@ with torch.no_grad():
             torch.arange(0, order),
             order
         )
-    ).tolist()
+    )
+
+fig = pyplot.figure()
+ax = SubplotZero(fig, 111)
+fig.add_subplot(ax)
+
+# Extend axes symmetrically around zero so that they fit data
+extent = max(
+    abs(encoded_vectors.min()),
+    abs(encoded_vectors.max())
+) * 1.05
+
+ax.set_xlim(-extent, extent)
+ax.set_ylim(-extent, extent)
+
+# Hide borders
+for direction in ['left', 'bottom', 'right', 'top']:
+    ax.axis[direction].set_visible(False)
+
+# Show zero-centered axes
+for direction in ['xzero', 'yzero']:
+    axis = ax.axis[direction]
+    axis.set_visible(True)
+    axis.set_axisline_style("-|>")
+
+# Configure axes ticks and labels
+ax.annotate(
+    'I', (1, 0.5), xycoords='axes fraction',
+    xytext=(25, 0), textcoords='offset points',
+    va='center', ha='right'
+)
+
+ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
+ax.axis['xzero'].major_ticklabels.set_ha('center')
+ax.axis['xzero'].major_ticklabels.set_va('top')
+
+ax.annotate(
+    'Q', (0.5, 1), xycoords='axes fraction',
+    xytext=(0, 25), textcoords='offset points',
+    va='center', ha='center'
+)
+
+ax.axis['yzero'].major_ticklabels.set_rotation(-90)
+ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
+ax.axis['yzero'].major_ticklabels.set_ha('left')
+ax.axis['yzero'].major_ticklabels.set_va('center')
+
+# Add a single tick on 0
+ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0])
+ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0])
+
+ax.annotate(
+    '0', (0, 0),
+    xytext=(15, -10), textcoords='offset points',
+    va='center', ha='center'
+)
+
+ax.grid()
+
+# Plot encoded vectors
+ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10)
+
+# Add index label for each vector
+for row in range(order):
+    ax.annotate(
+        row + 1, encoded_vectors[row],
+        xytext=(5, 5), textcoords='offset points',
+        backgroundcolor='white', zorder=9
+    )
 
-fig, axis = pyplot.subplots()
-axis.scatter(*zip(*encoded_vectors))
 pyplot.show()
diff --git a/train.py b/train.py
index 214bc11..3d1eacd 100644
--- a/train.py
+++ b/train.py
@@ -15,11 +15,14 @@ num_epochs = 20000
 loss_report_epoch_skip = 500
 
 # File in which the trained model is saved
-output_file = 'output/constellation-net.tc'
-
-print('Starting training with {} epochs\n'.format(num_epochs))
+output_file = 'output/constellation-order-{}.tc'.format(order)
 
 model = constellation.ConstellationNet(order=order)
+
+# Train the model with random data
+model.train()
+print('Starting training with {} epochs\n'.format(num_epochs))
+
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters())
 
@@ -46,24 +49,22 @@ for epoch in range(num_epochs):
 print('\nFinished training\n')
 
 # Print some examples of reconstruction
-with torch.no_grad():
-    num_examples = 5
+model.eval()
+print('Reconstruction examples:')
+print('Input vector\t\t\tOutput vector after softmax')
 
-    classes_example = util.get_random_messages(num_examples, order)
-    onehot_example = util.messages_to_onehot(classes_example, order)
+with torch.no_grad():
+    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
     raw_output = model(onehot_example)
     raw_output.required_grad = False
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)
 
-    print('Reconstruction examples:')
-    print('Input vector\t\t\tOutput vector after softmax')
-
-    for example_index in range(num_examples):
+    for index in range(order):
         print('{}\t\t{}'.format(
-            onehot_example[example_index].tolist(),
+            onehot_example[index].tolist(),
             '[{}]'.format(', '.join(
                 '{:.5f}'.format(x)
-                for x in reconstructed_example[example_index].tolist()
+                for x in reconstructed_example[index].tolist()
             ))
         ))
 