Balance training examples
This commit is contained in:
parent
99c96162c0
commit
6ea0e653c1
7
train.py
7
train.py
|
@ -6,10 +6,10 @@ import torch
|
|||
order = 4
|
||||
|
||||
# Number of training examples in an epoch
|
||||
epoch_size = 10000
|
||||
epoch_size_multiple = 8
|
||||
|
||||
# Number of epochs
|
||||
num_epochs = 20000
|
||||
num_epochs = 5000
|
||||
|
||||
# Number of epochs to skip between every loss report
|
||||
loss_report_epoch_skip = 500
|
||||
|
@ -32,9 +32,10 @@ criterion = torch.nn.CrossEntropyLoss()
|
|||
optimizer = torch.optim.Adam(model.parameters())
|
||||
|
||||
running_loss = 0
|
||||
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
classes_dataset = util.get_random_messages(epoch_size, order)
|
||||
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
||||
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
|
Loading…
Reference in New Issue