diff --git a/train.py b/train.py index d2a4b38..011d3e4 100644 --- a/train.py +++ b/train.py @@ -105,6 +105,18 @@ while total_change >= 1e-4: batch += 1 +model.eval() + +# Calcul de la perte finale +with torch.no_grad(): + classes_ordered = torch.arange(order).repeat(2048) + classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] + onehot_dataset = util.messages_to_onehot(classes_dataset, order) + + predictions = model(onehot_dataset) + final_loss = criterion(predictions, classes_dataset) + print('\nFinished training') +print('Final loss is {}'.format(final_loss)) print('Saving model as {}'.format(output_file)) torch.save(model.state_dict(), output_file)