Balance training examples

This commit is contained in:
Mattéo Delabre 2019-12-15 00:59:10 -05:00
parent 99c96162c0
commit 6ea0e653c1
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 4 additions and 3 deletions

View File

@ -6,10 +6,10 @@ import torch
order = 4 order = 4
# Number of training examples in an epoch # Number of training examples in an epoch
epoch_size = 10000 epoch_size_multiple = 8
# Number of epochs # Number of epochs
num_epochs = 20000 num_epochs = 5000
# Number of epochs to skip between every loss report # Number of epochs to skip between every loss report
loss_report_epoch_skip = 500 loss_report_epoch_skip = 500
@ -32,9 +32,10 @@ criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters()) optimizer = torch.optim.Adam(model.parameters())
running_loss = 0 running_loss = 0
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
for epoch in range(num_epochs): for epoch in range(num_epochs):
classes_dataset = util.get_random_messages(epoch_size, order) classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
onehot_dataset = util.messages_to_onehot(classes_dataset, order) onehot_dataset = util.messages_to_onehot(classes_dataset, order)
optimizer.zero_grad() optimizer.zero_grad()