Save trained models and plot encoding
parent 2af8354a07
commit 49e63775dd
@@ -0,0 +1,2 @@
+from constellation.ConstellationNet import ConstellationNet
+import constellation.util
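The file name for the two-line hunk above did not survive extraction; judging by its content it is most likely the package initializer (presumably constellation/__init__.py), re-exporting ConstellationNet at the package level so that the scripts below can write:

import constellation

model = constellation.ConstellationNet(order=4)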
@@ -0,0 +1,36 @@
+import torch
+
+
+def get_random_messages(count, order):
+    """
+    Generate a list of messages.
+
+    :param count: Number of messages to generate.
+    :param order: Number of possible messages.
+    :return: One-dimensional vector with each entry being the index of the
+    generated message, which is between 0, inclusive, and `order`, exclusive.
+
+    >>> get_random_messages(5, 5)
+    torch.tensor([0, 2, 0, 3, 4])
+    """
+    return torch.randint(0, order, (count,))
+
+
+def messages_to_onehot(messages, order):
+    """
+    Convert messages represented as indexes to one-hot encoding.
+
+    :param messages: List of messages to convert.
+    :param order: Number of possible messages.
+    :return: One-hot encoded messages.
+
+    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
+    torch.tensor([
+        [1., 0., 0., 0., 0.],
+        [0., 0., 1., 0., 0.],
+        [1., 0., 0., 0., 0.],
+        [0., 0., 0., 1., 0.],
+        [0., 0., 0., 0., 1.],
+    ])
+    """
+    return torch.nn.functional.one_hot(messages, num_classes=order).float()
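Note that the doctest-style examples above are illustrative rather than reproducible: torch.randint draws from the global RNG, so the exact outputs vary between runs. A minimal sanity-check sketch (not part of the commit) that seeds the RNG first:

import torch

from constellation import util

torch.manual_seed(0)  # fix the global RNG so the draw is repeatable

messages = util.get_random_messages(5, 4)   # e.g. tensor([0, 3, 1, 0, 3])
onehot = util.messages_to_onehot(messages, 4)

assert onehot.shape == (5, 4)
assert (onehot.sum(dim=1) == 1).all()  # exactly one hot entry per message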
@@ -0,0 +1,5 @@
+# Ignore all files in this directory…
+*
+
+# …except for this one.
+!.gitignore
@@ -0,0 +1,26 @@
+import constellation
+from constellation import util
+import torch
+from matplotlib import pyplot
+
+# Number of learned symbols
+order = 4
+
+# File in which the trained model is saved
+input_file = 'output/constellation-net.tc'
+
+model = constellation.ConstellationNet(order=order)
+model.load_state_dict(torch.load(input_file))
+
+# Compute encoded vectors
+with torch.no_grad():
+    encoded_vectors = model.encoder(
+        util.messages_to_onehot(
+            torch.arange(0, order),
+            order
+        )
+    ).tolist()
+
+fig, axis = pyplot.subplots()
+axis.scatter(*zip(*encoded_vectors))
+pyplot.show()
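If the constellation points are hard to tell apart, one possible extension (not part of the commit; it assumes the encoder output is 2-D, as the scatter call above already implies) is to label each point with the index of the message it encodes, inserted before the pyplot.show() call:

# Label every constellation point with the index of the message it
# encodes (assumes encoded_vectors holds 2-D points, as above).
for message, point in enumerate(encoded_vectors):
    axis.annotate(str(message), point, textcoords='offset points',
                  xytext=(5, 5))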
train.py
@@ -1,4 +1,5 @@
-from ConstellationNet import ConstellationNet
+import constellation
+from constellation import util
 import torch
 
 # Number of symbols to learn
@@ -8,23 +9,25 @@ order = 4
 epoch_size = 10000
 
 # Number of epochs
-num_epochs = 25000
+num_epochs = 20000
 
 # Number of epochs to skip between every loss report
-loss_report_epoch_skip = 200
+loss_report_epoch_skip = 500
 
-model = ConstellationNet(order=order)
+# File in which the trained model is saved
+output_file = 'output/constellation-net.tc'
+
+print('Starting training with {} epochs\n'.format(num_epochs))
+
+model = constellation.ConstellationNet(order=order)
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters())
 
 running_loss = 0
 
 for epoch in range(num_epochs):
-    classes_dataset = torch.randint(0, order, (epoch_size,))
-    onehot_dataset = torch.nn.functional.one_hot(
-        classes_dataset,
-        num_classes=order
-    ).float()
+    classes_dataset = util.get_random_messages(epoch_size, order)
+    onehot_dataset = util.messages_to_onehot(classes_dataset, order)
 
     optimizer.zero_grad()
     predictions = model(onehot_dataset)
@@ -40,8 +43,29 @@ for epoch in range(num_epochs):
         print('Loss is {}'.format(running_loss))
         running_loss = 0
 
-# Test the model with class 1
-print(model(torch.nn.functional.one_hot(torch.tensor(0), num_classes=order).float()))
-
-# Test the model with class 2
-print(model(torch.nn.functional.one_hot(torch.tensor(1), num_classes=order).float()))
+print('\nFinished training\n')
+
+# Print some examples of reconstruction
+with torch.no_grad():
+    num_examples = 5
+
+    classes_example = util.get_random_messages(num_examples, order)
+    onehot_example = util.messages_to_onehot(classes_example, order)
+    raw_output = model(onehot_example)
+    raw_output.requires_grad = False
+    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)
+
+    print('Reconstruction examples:')
+    print('Input vector\t\t\tOutput vector after softmax')
+
+    for example_index in range(num_examples):
+        print('{}\t\t{}'.format(
+            onehot_example[example_index].tolist(),
+            '[{}]'.format(', '.join(
+                '{:.5f}'.format(x)
+                for x in reconstructed_example[example_index].tolist()
+            ))
+        ))
+
+print('\nSaving model as {}'.format(output_file))
+torch.save(model.state_dict(), output_file)
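Beyond eyeballing the softmax outputs, a natural follow-up is to estimate how often the autoencoder decodes the wrong symbol. A minimal sketch (not part of the commit; it assumes model, util and order from train.py are in scope):

# Estimate the symbol error rate on a fresh batch of random messages.
with torch.no_grad():
    test_classes = util.get_random_messages(10000, order)
    test_onehot = util.messages_to_onehot(test_classes, order)

    # The class with the highest logit is the decoded symbol.
    decisions = model(test_onehot).argmax(dim=1)
    error_rate = (decisions != test_classes).float().mean().item()

print('Estimated symbol error rate: {:.4f}'.format(error_rate))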