Increase batch size halfway through training
This commit is contained in:
parent
3365603614
commit
989a51b72e
15
train.py
15
train.py
|
@ -5,11 +5,8 @@ import torch
|
||||||
# Number of symbols to learn
|
# Number of symbols to learn
|
||||||
order = 4
|
order = 4
|
||||||
|
|
||||||
# Number of training examples in an epoch
|
|
||||||
epoch_size_multiple = 8
|
|
||||||
|
|
||||||
# Number of epochs
|
# Number of epochs
|
||||||
num_epochs = 5000
|
num_epochs = 10000
|
||||||
|
|
||||||
# Number of epochs to skip between every loss report
|
# Number of epochs to skip between every loss report
|
||||||
loss_report_epoch_skip = 500
|
loss_report_epoch_skip = 500
|
||||||
|
@ -32,9 +29,15 @@ criterion = torch.nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters())
|
optimizer = torch.optim.Adam(model.parameters())
|
||||||
|
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
|
prev_epoch_size_multiple = 0
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
|
epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
|
||||||
|
|
||||||
|
if epoch_size_multiple != prev_epoch_size_multiple:
|
||||||
|
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
|
||||||
|
prev_epoch_size_multiple = epoch_size_multiple
|
||||||
|
|
||||||
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
||||||
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
||||||
|
|
||||||
|
@ -49,7 +52,7 @@ for epoch in range(num_epochs):
|
||||||
|
|
||||||
if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
|
if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
|
||||||
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
|
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
|
||||||
print('Loss is {}'.format(running_loss))
|
print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
|
|
||||||
print('\nFinished training\n')
|
print('\nFinished training\n')
|
||||||
|
|
Loading…
Reference in New Issue