from ConstellationNet import ConstellationNet
import torch

# Number of symbols to learn
order = 4

# Number of training examples in an epoch
epoch_size = 10000

# Number of epochs
num_epochs = 25000

# Number of epochs to skip between every loss report
loss_report_epoch_skip = 200

model = ConstellationNet(order=order)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

running_loss = 0

for epoch in range(num_epochs):
    classes_dataset = torch.randint(0, order, (epoch_size,))
    onehot_dataset = torch.nn.functional.one_hot(
        classes_dataset, num_classes=order
    ).float()

    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Report loss
    running_loss += loss.item()

    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('Loss is {}'.format(running_loss))
        running_loss = 0

# Test the model with class 1
print(model(torch.nn.functional.one_hot(torch.tensor(0), num_classes=order).float()))

# Test the model with class 2
print(model(torch.nn.functional.one_hot(torch.tensor(1), num_classes=order).float()))
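
# A minimal sketch for estimating symbol accuracy after training, assuming
# (as in the training loop above) that the model's forward pass returns one
# logit per class for each one-hot input. The batch size of 10000 here is an
# arbitrary choice for illustration.
with torch.no_grad():
    test_classes = torch.randint(0, order, (10000,))
    test_onehot = torch.nn.functional.one_hot(
        test_classes, num_classes=order
    ).float()

    # Pick the most likely symbol for each example and compare to the truth
    predicted_classes = model(test_onehot).argmax(dim=1)
    accuracy = (predicted_classes == test_classes).float().mean().item()

print('Symbol accuracy is {}'.format(accuracy))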