48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
from ConstellationNet import ConstellationNet
|
|
import torch
|
|
|
|
# Constellation size: how many distinct symbols the network must learn.
order = 4

# Batch size: how many random symbols are drawn for every training iteration.
epoch_size = 10_000

# Total number of training iterations (called "epochs" throughout this script).
num_epochs = 25_000

# Iterations between consecutive loss printouts.
loss_report_epoch_skip = 200
|
|
|
|
# Cross-entropy between the network's outputs and the integer symbol indices
# used as targets during training (averaged over the batch).
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

# The network under training; `order` tells it how many symbols to handle.
# NOTE(review): ConstellationNet is defined elsewhere. Its output is assumed to
# be (batch, order) logits, as CrossEntropyLoss expects — confirm.
model = ConstellationNet(order=order)

# Adam over every trainable parameter, with PyTorch's default settings
# (lr=1e-3).
optimizer = torch.optim.Adam(params=model.parameters())
|
|
|
|
# Sum of per-step losses since the last report; reset after each printout.
running_loss = 0.0

# Each "epoch" here is a single optimisation step on a freshly sampled batch
# of `epoch_size` random symbols, not a pass over a fixed dataset.
for epoch in range(num_epochs):
    # Draw random symbol indices and their one-hot encodings. The network is
    # trained to recover each symbol's index from its one-hot input.
    classes_dataset = torch.randint(0, order, (epoch_size,))
    onehot_dataset = torch.nn.functional.one_hot(
        classes_dataset,
        num_classes=order,
    ).float()

    # Standard forward / backward / update step.
    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # .item() detaches the value so no autograd graph is kept alive across steps.
    running_loss += loss.item()

    # Report once every `loss_report_epoch_skip` epochs. Averaging keeps the
    # printed value on the same scale as a single step's loss, independent of
    # how long the reporting window is.
    if (epoch + 1) % loss_report_epoch_skip == 0:
        mean_loss = running_loss / loss_report_epoch_skip
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('Mean loss over last {} epochs is {}'.format(
            loss_report_epoch_skip, mean_loss))
        running_loss = 0.0
|
|
|
|
# Print the trained network's output for the first two symbols (class
# indices 0 and 1).
#
# Gradient tracking is disabled because this is inference only. No autograd
# graph is built, and the printed tensors carry no grad_fn.
#
# NOTE(review): the model is still in training mode here. If ConstellationNet
# has layers that behave differently during training (dropout, batch norm, a
# training-only noise stage, ...), consider calling model.eval() first —
# confirm.
#
# NOTE(review): each input is a single unbatched one-hot vector of shape
# (order,). This assumes ConstellationNet accepts unbatched input — confirm.
with torch.no_grad():
    for class_index in (0, 1):
        onehot_input = torch.nn.functional.one_hot(
            torch.tensor(class_index), num_classes=order
        ).float()
        print(model(onehot_input))
|