|
|
@ -1,4 +1,5 @@ |
|
|
|
from ConstellationNet import ConstellationNet |
|
|
|
import constellation |
|
|
|
from constellation import util |
|
|
|
import torch |
|
|
|
|
|
|
|
# Number of symbols to learn |
|
|
@ -8,23 +9,25 @@ order = 4 |
|
|
|
epoch_size = 10000 |
|
|
|
|
|
|
|
# Number of epochs |
|
|
|
num_epochs = 25000 |
|
|
|
num_epochs = 20000 |
|
|
|
|
|
|
|
# Number of epochs to skip between every loss report |
|
|
|
loss_report_epoch_skip = 200 |
|
|
|
loss_report_epoch_skip = 500 |
|
|
|
|
|
|
|
model = ConstellationNet(order=order) |
|
|
|
# File in which the trained model is saved |
|
|
|
output_file = 'output/constellation-net.tc' |
|
|
|
|
|
|
|
print('Starting training with {} epochs\n'.format(num_epochs)) |
|
|
|
|
|
|
|
model = constellation.ConstellationNet(order=order) |
|
|
|
criterion = torch.nn.CrossEntropyLoss() |
|
|
|
optimizer = torch.optim.Adam(model.parameters()) |
|
|
|
|
|
|
|
running_loss = 0 |
|
|
|
|
|
|
|
for epoch in range(num_epochs): |
|
|
|
classes_dataset = torch.randint(0, order, (epoch_size,)) |
|
|
|
onehot_dataset = torch.nn.functional.one_hot( |
|
|
|
classes_dataset, |
|
|
|
num_classes=order |
|
|
|
).float() |
|
|
|
classes_dataset = util.get_random_messages(epoch_size, order) |
|
|
|
onehot_dataset = util.messages_to_onehot(classes_dataset, order) |
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
predictions = model(onehot_dataset) |
|
|
@ -40,8 +43,29 @@ for epoch in range(num_epochs): |
|
|
|
print('Loss is {}'.format(running_loss)) |
|
|
|
running_loss = 0 |
|
|
|
|
|
|
|
# Test the model with class 1 |
|
|
|
print(model(torch.nn.functional.one_hot(torch.tensor(0), num_classes=order).float())) |
|
|
|
print('\nFinished training\n') |
|
|
|
|
|
|
|
# Print some examples of reconstruction |
|
|
|
with torch.no_grad(): |
|
|
|
num_examples = 5 |
|
|
|
|
|
|
|
classes_example = util.get_random_messages(num_examples, order) |
|
|
|
onehot_example = util.messages_to_onehot(classes_example, order) |
|
|
|
raw_output = model(onehot_example) |
|
|
|
raw_output.required_grad = False |
|
|
|
reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1) |
|
|
|
|
|
|
|
print('Reconstruction examples:') |
|
|
|
print('Input vector\t\t\tOutput vector after softmax') |
|
|
|
|
|
|
|
for example_index in range(num_examples): |
|
|
|
print('{}\t\t{}'.format( |
|
|
|
onehot_example[example_index].tolist(), |
|
|
|
'[{}]'.format(', '.join( |
|
|
|
'{:.5f}'.format(x) |
|
|
|
for x in reconstructed_example[example_index].tolist() |
|
|
|
)) |
|
|
|
)) |
|
|
|
|
|
|
|
# Test the model with class 2 |
|
|
|
print(model(torch.nn.functional.one_hot(torch.tensor(1), num_classes=order).float())) |
|
|
|
print('\nSaving model as {}'.format(output_file)) |
|
|
|
torch.save(model.state_dict(), output_file) |
|
|
|