diff --git a/train.py b/train.py
index abf394f..838f70d 100644
--- a/train.py
+++ b/train.py
@@ -5,11 +5,8 @@ import torch
 # Number of symbols to learn
 order = 4
 
-# Number of training examples in an epoch
-epoch_size_multiple = 8
-
 # Number of epochs
-num_epochs = 5000
+num_epochs = 10000
 
 # Number of epochs to skip between every loss report
 loss_report_epoch_skip = 500
@@ -32,9 +29,15 @@ criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters())
 running_loss = 0
 
-classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
+prev_epoch_size_multiple = 0
 
 for epoch in range(num_epochs):
+    epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
+
+    if epoch_size_multiple != prev_epoch_size_multiple:
+        classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
+        prev_epoch_size_multiple = epoch_size_multiple
+
     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
 
@@ -49,7 +52,7 @@ for epoch in range(num_epochs):
 
     if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
         print('Epoch {}/{}'.format(epoch + 1, num_epochs))
-        print('Loss is {}'.format(running_loss))
+        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
         running_loss = 0
 
 print('\nFinished training\n')
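
# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of how the training loop reads after this
# patch. The real `model` and `util.messages_to_onehot` are not shown in the
# diff, so a placeholder linear model and torch.nn.functional.one_hot stand in
# for them, and the forward/backward/step in the middle of the loop follows
# the usual PyTorch pattern rather than the repo's exact code.
# ---------------------------------------------------------------------------
import torch

order = 4
num_epochs = 10000
loss_report_epoch_skip = 500

model = torch.nn.Linear(order, order)  # placeholder for the real network
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
running_loss = 0

prev_epoch_size_multiple = 0

for epoch in range(num_epochs):
    # Small epochs for the first half of training, much larger ones afterwards.
    epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048

    # Rebuild the ordered class tensor only when the multiple actually changes.
    if epoch_size_multiple != prev_epoch_size_multiple:
        classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
        prev_epoch_size_multiple = epoch_size_multiple

    # Shuffle the classes and one-hot encode them for this epoch.
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = torch.nn.functional.one_hot(classes_dataset, order).float()

    optimizer.zero_grad()
    loss = criterion(model(onehot_dataset), classes_dataset)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()

    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        # Report the mean loss over the reporting window, not the raw sum.
        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
        running_loss = 0

print('\nFinished training\n')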