Increase batch size halfway through training

This commit is contained in:
Mattéo Delabre 2019-12-15 09:42:48 -05:00
parent 3365603614
commit 989a51b72e
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 9 additions and 6 deletions

View File

@ -5,11 +5,8 @@ import torch
# Number of symbols to learn # Number of symbols to learn
order = 4 order = 4
# Number of training examples in an epoch
epoch_size_multiple = 8
# Number of epochs # Number of epochs
num_epochs = 5000 num_epochs = 10000
# Number of epochs to skip between every loss report # Number of epochs to skip between every loss report
loss_report_epoch_skip = 500 loss_report_epoch_skip = 500
@ -32,9 +29,15 @@ criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters()) optimizer = torch.optim.Adam(model.parameters())
running_loss = 0 running_loss = 0
classes_ordered = torch.arange(order).repeat(epoch_size_multiple) prev_epoch_size_multiple = 0
for epoch in range(num_epochs): for epoch in range(num_epochs):
epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
if epoch_size_multiple != prev_epoch_size_multiple:
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
prev_epoch_size_multiple = epoch_size_multiple
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
onehot_dataset = util.messages_to_onehot(classes_dataset, order) onehot_dataset = util.messages_to_onehot(classes_dataset, order)
@ -49,7 +52,7 @@ for epoch in range(num_epochs):
if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1: if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
print('Epoch {}/{}'.format(epoch + 1, num_epochs)) print('Epoch {}/{}'.format(epoch + 1, num_epochs))
print('Loss is {}'.format(running_loss)) print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
running_loss = 0 running_loss = 0
print('\nFinished training\n') print('\nFinished training\n')