Save whole model to avoid definition duplicate
This commit is contained in:
parent
0769a61fcf
commit
a7e9dd2230
17
plot.py
17
plot.py
|
@ -8,21 +8,14 @@ from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||||
# Number learned symbols
|
# Number learned symbols
|
||||||
order = 16
|
order = 16
|
||||||
|
|
||||||
|
# Color map used for decision regions and points
|
||||||
|
color_map = matplotlib.cm.Dark2
|
||||||
|
|
||||||
# File in which the trained model is saved
|
# File in which the trained model is saved
|
||||||
input_file = 'output/constellation-order-{}.pth'.format(order)
|
input_file = 'output/constellation-order-{}.pth'.format(order)
|
||||||
|
|
||||||
# Color map used for decision regions
|
|
||||||
color_map = matplotlib.cm.Dark2
|
|
||||||
|
|
||||||
# Restore model from file
|
# Restore model from file
|
||||||
model = constellation.ConstellationNet(
|
model = torch.load(input_file)
|
||||||
order=order,
|
|
||||||
encoder_layers_sizes=(8, 4),
|
|
||||||
decoder_layers_sizes=(4, 8),
|
|
||||||
channel_model=constellation.GaussianChannel()
|
|
||||||
)
|
|
||||||
|
|
||||||
model.load_state_dict(torch.load(input_file))
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# Setup plot
|
# Setup plot
|
||||||
|
@ -34,7 +27,7 @@ constellation = model.get_constellation()
|
||||||
util.plot_constellation(
|
util.plot_constellation(
|
||||||
ax, constellation,
|
ax, constellation,
|
||||||
model.channel, model.decoder,
|
model.channel, model.decoder,
|
||||||
grid_step=0.001, noise_samples=5000
|
grid_step=0.001, noise_samples=0
|
||||||
)
|
)
|
||||||
|
|
||||||
pyplot.show()
|
pyplot.show()
|
||||||
|
|
16
train.py
16
train.py
|
@ -3,12 +3,16 @@ from constellation import util
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot
|
from matplotlib import pyplot
|
||||||
from mpl_toolkits.axisartist.axislines import SubplotZero
|
from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||||
|
import warnings
|
||||||
|
|
||||||
torch.manual_seed(57)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
# Number of symbols to learn
|
# Number of symbols to learn
|
||||||
order = 16
|
order = 16
|
||||||
|
|
||||||
|
# Shape of the hidden layers
|
||||||
|
hidden_layers = (8, 4,)
|
||||||
|
|
||||||
# Initial value for the learning rate
|
# Initial value for the learning rate
|
||||||
initial_learning_rate = 0.1
|
initial_learning_rate = 0.1
|
||||||
|
|
||||||
|
@ -33,9 +37,8 @@ pyplot.show(block=False)
|
||||||
# Train the model with random data
|
# Train the model with random data
|
||||||
model = constellation.ConstellationNet(
|
model = constellation.ConstellationNet(
|
||||||
order=order,
|
order=order,
|
||||||
encoder_layers_sizes=(8, 4,),
|
encoder_layers=hidden_layers,
|
||||||
decoder_layers_sizes=(4, 8,),
|
decoder_layers=hidden_layers[::-1],
|
||||||
channel_model=constellation.GaussianChannel()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print('Starting training\n')
|
print('Starting training\n')
|
||||||
|
@ -122,4 +125,7 @@ with torch.no_grad():
|
||||||
print('\nFinished training')
|
print('\nFinished training')
|
||||||
print('Final loss is {}'.format(final_loss))
|
print('Final loss is {}'.format(final_loss))
|
||||||
print('Saving model as {}'.format(output_file))
|
print('Saving model as {}'.format(output_file))
|
||||||
torch.save(model.state_dict(), output_file)
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter('ignore')
|
||||||
|
torch.save(model, output_file)
|
||||||
|
|
Loading…
Reference in New Issue