From fb2518b321357720340163e679bf7d49ad8ce1ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Mon, 16 Dec 2019 10:20:01 -0500 Subject: [PATCH] Show final loss after training --- train.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/train.py b/train.py index d2a4b38..011d3e4 100644 --- a/train.py +++ b/train.py @@ -105,6 +105,18 @@ while total_change >= 1e-4: batch += 1 +model.eval() + +# Calcul de la perte finale +with torch.no_grad(): + classes_ordered = torch.arange(order).repeat(2048) + classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] + onehot_dataset = util.messages_to_onehot(classes_dataset, order) + + predictions = model(onehot_dataset) + final_loss = criterion(predictions, classes_dataset) + print('\nFinished training') +print('Final loss is {}'.format(final_loss)) print('Saving model as {}'.format(output_file)) torch.save(model.state_dict(), output_file)