|
|
@ -3,12 +3,16 @@ from constellation import util |
|
|
|
import torch |
|
|
|
from matplotlib import pyplot |
|
|
|
from mpl_toolkits.axisartist.axislines import SubplotZero |
|
|
|
import warnings |
|
|
|
|
|
|
|
torch.manual_seed(57) |
|
|
|
torch.manual_seed(42) |
|
|
|
|
|
|
|
# Number of symbols to learn |
|
|
|
order = 16 |
|
|
|
|
|
|
|
# Shape of the hidden layers |
|
|
|
hidden_layers = (8, 4,) |
|
|
|
|
|
|
|
# Initial value for the learning rate |
|
|
|
initial_learning_rate = 0.1 |
|
|
|
|
|
|
@ -33,9 +37,8 @@ pyplot.show(block=False) |
|
|
|
# Train the model with random data |
|
|
|
model = constellation.ConstellationNet( |
|
|
|
order=order, |
|
|
|
encoder_layers_sizes=(8, 4,), |
|
|
|
decoder_layers_sizes=(4, 8,), |
|
|
|
channel_model=constellation.GaussianChannel() |
|
|
|
encoder_layers=hidden_layers, |
|
|
|
decoder_layers=hidden_layers[::-1], |
|
|
|
) |
|
|
|
|
|
|
|
print('Starting training\n') |
|
|
@ -122,4 +125,7 @@ with torch.no_grad(): |
|
|
|
print('\nFinished training') |
|
|
|
print('Final loss is {}'.format(final_loss)) |
|
|
|
print('Saving model as {}'.format(output_file)) |
|
|
|
torch.save(model.state_dict(), output_file) |
|
|
|
|
|
|
|
with warnings.catch_warnings(): |
|
|
|
warnings.simplefilter('ignore') |
|
|
|
torch.save(model, output_file) |
|
|
|