# constellationnet/train.py
#
# Training script for ConstellationNet (repository snapshot: 48 lines,
# 1.3 KiB, Python; last modified 2019-12-13 17:11:09 +00:00).
from ConstellationNet import ConstellationNet
import torch
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# Size of the symbol alphabet the network has to learn.
order = 4
# How many random training examples are drawn for every epoch.
epoch_size = 10000
# Total number of epochs to train for.
num_epochs = 25000
# Number of epochs accumulated between two consecutive loss reports.
loss_report_epoch_skip = 200

# ---------------------------------------------------------------------------
# Objective, network and optimiser
# ---------------------------------------------------------------------------
# Classification objective: compares the network output against the
# integer class labels.
criterion = torch.nn.CrossEntropyLoss()
model = ConstellationNet(order=order)
# Adam with its default hyper-parameters over every trainable parameter.
optimizer = torch.optim.Adam(params=model.parameters())
# Sum of the per-epoch losses accumulated since the last report.
running_loss = 0.0
for epoch in range(num_epochs):
    # Draw a fresh batch of random class labels every epoch and one-hot
    # encode them; the network is trained to recover each label from its
    # own one-hot encoding.
    classes_dataset = torch.randint(0, order, (epoch_size,))
    onehot_dataset = torch.nn.functional.one_hot(
        classes_dataset,
        num_classes=order,
    ).float()

    # One optimisation step over the whole batch.
    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Report loss: print the *mean* loss over the last
    # `loss_report_epoch_skip` epochs (the raw running sum would be
    # `loss_report_epoch_skip` times larger than any single loss value).
    running_loss += loss.item()
    if (epoch + 1) % loss_report_epoch_skip == 0:
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
        running_loss = 0.0
# Sanity-check the trained model on single one-hot inputs for class
# indices 0 and 1 (the first two symbols) and print the raw outputs.
# torch.no_grad() is used because these calls are inference only: it
# avoids building an autograd graph and keeps `grad_fn` out of the output.
with torch.no_grad():
    for class_index in (0, 1):
        onehot_input = torch.nn.functional.one_hot(
            torch.tensor(class_index),
            num_classes=order,
        ).float()
        print(model(onehot_input))