From a7e9dd2230b902dc78ab10a90a4c6c4da94abea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Wed, 18 Dec 2019 10:17:59 -0500 Subject: [PATCH] Save whole model to avoid definition duplicate --- plot.py | 17 +++++------------ train.py | 16 +++++++++++----- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/plot.py b/plot.py index b33be3c..b5ec18e 100644 --- a/plot.py +++ b/plot.py @@ -8,21 +8,14 @@ from mpl_toolkits.axisartist.axislines import SubplotZero # Number learned symbols order = 16 +# Color map used for decision regions and points +color_map = matplotlib.cm.Dark2 + # File in which the trained model is saved input_file = 'output/constellation-order-{}.pth'.format(order) -# Color map used for decision regions -color_map = matplotlib.cm.Dark2 - # Restore model from file -model = constellation.ConstellationNet( - order=order, - encoder_layers_sizes=(8, 4), - decoder_layers_sizes=(4, 8), - channel_model=constellation.GaussianChannel() -) - -model.load_state_dict(torch.load(input_file)) +model = torch.load(input_file) model.eval() # Setup plot @@ -34,7 +27,7 @@ constellation = model.get_constellation() util.plot_constellation( ax, constellation, model.channel, model.decoder, - grid_step=0.001, noise_samples=5000 + grid_step=0.001, noise_samples=0 ) pyplot.show() diff --git a/train.py b/train.py index 24d0155..1c2b950 100644 --- a/train.py +++ b/train.py @@ -3,12 +3,16 @@ from constellation import util import torch from matplotlib import pyplot from mpl_toolkits.axisartist.axislines import SubplotZero +import warnings -torch.manual_seed(57) +torch.manual_seed(42) # Number of symbols to learn order = 16 +# Shape of the hidden layers +hidden_layers = (8, 4,) + # Initial value for the learning rate initial_learning_rate = 0.1 @@ -33,9 +37,8 @@ pyplot.show(block=False) # Train the model with random data model = constellation.ConstellationNet( order=order, - encoder_layers_sizes=(8, 4,), - decoder_layers_sizes=(4, 8,), - channel_model=constellation.GaussianChannel() + encoder_layers=hidden_layers, + decoder_layers=hidden_layers[::-1], ) print('Starting training\n') @@ -122,4 +125,7 @@ with torch.no_grad(): print('\nFinished training') print('Final loss is {}'.format(final_loss)) print('Saving model as {}'.format(output_file)) -torch.save(model.state_dict(), output_file) + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + torch.save(model, output_file)