From 6ea0e653c1099e82aed468a75bb0c9bd3a8c2e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Sun, 15 Dec 2019 00:59:10 -0500 Subject: [PATCH] Balance training examples --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index edb1ad2..abf394f 100644 --- a/train.py +++ b/train.py @@ -6,10 +6,10 @@ import torch order = 4 # Number of training examples in an epoch -epoch_size = 10000 +epoch_size_multiple = 8 # Number of epochs -num_epochs = 20000 +num_epochs = 5000 # Number of epochs to skip between every loss report loss_report_epoch_skip = 500 @@ -32,9 +32,10 @@ criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters()) running_loss = 0 +classes_ordered = torch.arange(order).repeat(epoch_size_multiple) for epoch in range(num_epochs): - classes_dataset = util.get_random_messages(epoch_size, order) + classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] onehot_dataset = util.messages_to_onehot(classes_dataset, order) optimizer.zero_grad()