diff --git a/ConstellationNet.py b/constellation/ConstellationNet.py similarity index 100% rename from ConstellationNet.py rename to constellation/ConstellationNet.py diff --git a/constellation/__init__.py b/constellation/__init__.py new file mode 100644 index 0000000..e461218 --- /dev/null +++ b/constellation/__init__.py @@ -0,0 +1,2 @@ +from constellation.ConstellationNet import ConstellationNet +import constellation.util diff --git a/constellation/util.py b/constellation/util.py new file mode 100644 index 0000000..3b4197f --- /dev/null +++ b/constellation/util.py @@ -0,0 +1,36 @@ +import torch + + +def get_random_messages(count, order): + """ + Generate a list of messages. + + :param count: Number of messages to generate. + :param order: Number of possible messages. + :return: One-dimensional vector with each entry being the index of the + generated message which is between 0, inclusive, and `order`, exclusive. + + >>> get_random_messages(5, 5) + torch.tensor([0, 2, 0, 3, 4]) + """ + return torch.randint(0, order, (count,)) + + +def messages_to_onehot(messages, order): + """ + Convert messages represented as indexes to one-hot encoding. + + :param messages: List of messages to convert. + :param order: Number of possible messages. + :return: One-hot encoded messages. + + >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5) + torch.tensor([ + [1., 0., 0., 0., 0.], + [0., 0., 1., 0., 0.], + [1., 0., 0., 0., 0.], + [0., 0., 0., 1., 0.], + [0., 0., 0., 0., 1.], + ]) + """ + return torch.nn.functional.one_hot(messages, num_classes=order).float() diff --git a/output/.gitignore b/output/.gitignore new file mode 100644 index 0000000..a26b325 --- /dev/null +++ b/output/.gitignore @@ -0,0 +1,5 @@ +# Ignore all files in this directory… +* + +# …except for this one. 
+!.gitignore diff --git a/plot.py b/plot.py new file mode 100644 index 0000000..7c0bd9b --- /dev/null +++ b/plot.py @@ -0,0 +1,26 @@ +import constellation +from constellation import util +import torch +from matplotlib import pyplot + +# Number of learned symbols +order = 4 + +# File in which the trained model is saved +input_file = 'output/constellation-net.tc' + +model = constellation.ConstellationNet(order=order) +model.load_state_dict(torch.load(input_file)) + +# Compute encoded vectors +with torch.no_grad(): + encoded_vectors = model.encoder( + util.messages_to_onehot( + torch.arange(0, order), + order + ) + ).tolist() + +fig, axis = pyplot.subplots() +axis.scatter(*zip(*encoded_vectors)) +pyplot.show() diff --git a/train.py b/train.py index ad571eb..214bc11 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ -from ConstellationNet import ConstellationNet +import constellation +from constellation import util import torch # Number of symbols to learn @@ -8,23 +9,25 @@ order = 4 epoch_size = 10000 # Number of epochs -num_epochs = 25000 +num_epochs = 20000 # Number of epochs to skip between every loss report -loss_report_epoch_skip = 200 +loss_report_epoch_skip = 500 -model = ConstellationNet(order=order) +# File in which the trained model is saved +output_file = 'output/constellation-net.tc' + +print('Starting training with {} epochs\n'.format(num_epochs)) + +model = constellation.ConstellationNet(order=order) criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters()) running_loss = 0 for epoch in range(num_epochs): - classes_dataset = torch.randint(0, order, (epoch_size,)) - onehot_dataset = torch.nn.functional.one_hot( - classes_dataset, - num_classes=order - ).float() + classes_dataset = util.get_random_messages(epoch_size, order) + onehot_dataset = util.messages_to_onehot(classes_dataset, order) optimizer.zero_grad() predictions = model(onehot_dataset) @@ -40,8 +43,29 @@ for epoch in range(num_epochs): print('Loss is 
{}'.format(running_loss)) running_loss = 0 -# Test the model with class 1 -print(model(torch.nn.functional.one_hot(torch.tensor(0), num_classes=order).float())) +print('\nFinished training\n') -# Test the model with class 2 -print(model(torch.nn.functional.one_hot(torch.tensor(1), num_classes=order).float())) +# Print some examples of reconstruction +with torch.no_grad(): + num_examples = 5 + + classes_example = util.get_random_messages(num_examples, order) + onehot_example = util.messages_to_onehot(classes_example, order) + raw_output = model(onehot_example) + raw_output.requires_grad = False + reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1) + + print('Reconstruction examples:') + print('Input vector\t\t\tOutput vector after softmax') + + for example_index in range(num_examples): + print('{}\t\t{}'.format( + onehot_example[example_index].tolist(), + '[{}]'.format(', '.join( + '{:.5f}'.format(x) + for x in reconstructed_example[example_index].tolist() + )) + )) + +print('\nSaving model as {}'.format(output_file)) +torch.save(model.state_dict(), output_file)