Save whole model to avoid definition duplicate

This commit is contained in:
Mattéo Delabre 2019-12-18 10:17:59 -05:00
parent 0769a61fcf
commit a7e9dd2230
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
2 changed files with 16 additions and 17 deletions

17
plot.py
View File

@ -8,21 +8,14 @@ from mpl_toolkits.axisartist.axislines import SubplotZero
# Number learned symbols
order = 16
# Color map used for decision regions and points
color_map = matplotlib.cm.Dark2
# File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order)
# Color map used for decision regions
color_map = matplotlib.cm.Dark2
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(8, 4),
decoder_layers_sizes=(4, 8),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model = torch.load(input_file)
model.eval()
# Setup plot
@ -34,7 +27,7 @@ constellation = model.get_constellation()
util.plot_constellation(
ax, constellation,
model.channel, model.decoder,
grid_step=0.001, noise_samples=5000
grid_step=0.001, noise_samples=0
)
pyplot.show()

View File

@ -3,12 +3,16 @@ from constellation import util
import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero
import warnings
torch.manual_seed(57)
torch.manual_seed(42)
# Number of symbols to learn
order = 16
# Shape of the hidden layers
hidden_layers = (8, 4,)
# Initial value for the learning rate
initial_learning_rate = 0.1
@ -33,9 +37,8 @@ pyplot.show(block=False)
# Train the model with random data
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(8, 4,),
decoder_layers_sizes=(4, 8,),
channel_model=constellation.GaussianChannel()
encoder_layers=hidden_layers,
decoder_layers=hidden_layers[::-1],
)
print('Starting training\n')
@ -122,4 +125,7 @@ with torch.no_grad():
print('\nFinished training')
print('Final loss is {}'.format(final_loss))
print('Saving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
torch.save(model, output_file)