diff --git a/train.py b/train.py
index edb1ad2..abf394f 100644
--- a/train.py
+++ b/train.py
@@ -6,10 +6,10 @@ import torch
 order = 4
 
-# Number of training examples in an epoch
-epoch_size = 10000
+# Number of training examples per class in an epoch
+epoch_size_multiple = 8
 
 # Number of epochs
-num_epochs = 20000
+num_epochs = 5000
 
 # Number of epochs to skip between every loss report
 loss_report_epoch_skip = 500
@@ -32,9 +32,11 @@
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters())
 
 running_loss = 0
+# Balanced label pool: each of the `order` classes repeated epoch_size_multiple times
+classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
 for epoch in range(num_epochs):
-    classes_dataset = util.get_random_messages(epoch_size, order)
+    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
 
     optimizer.zero_grad()
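
For context, a minimal standalone sketch of what the new per-epoch dataset construction does. `util.messages_to_onehot` is the repo's helper; `torch.nn.functional.one_hot` is assumed here as a stand-in for it:

```python
import torch

order = 4                # number of distinct classes/messages
epoch_size_multiple = 8  # copies of each class per epoch

# Balanced pool: tensor([0, 1, 2, 3, 0, 1, 2, 3, ...]), 32 labels total,
# with every class appearing exactly epoch_size_multiple times.
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)

# Each epoch reshuffles the same pool, so class frequencies stay exactly
# uniform across epochs (unlike independent random sampling).
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]

# Stand-in for util.messages_to_onehot (assumed to be plain one-hot encoding).
onehot_dataset = torch.nn.functional.one_hot(classes_dataset, num_classes=order).float()
```

Compared to `util.get_random_messages(epoch_size, order)`, this trades 10000 independently sampled examples per epoch for a fixed pool of `order * epoch_size_multiple = 32` examples that is merely reshuffled, guaranteeing uniform class coverage in every epoch.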