Show final loss after training

This commit is contained in:
Mattéo Delabre 2019-12-16 10:20:01 -05:00
parent 1d39184036
commit fb2518b321
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 12 additions and 0 deletions

View File

@ -105,6 +105,18 @@ while total_change >= 1e-4:
batch += 1
model.eval()
# Calcul de la perte finale
with torch.no_grad():
classes_ordered = torch.arange(order).repeat(2048)
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
predictions = model(onehot_dataset)
final_loss = criterion(predictions, classes_dataset)
print('\nFinished training')
print('Final loss is {}'.format(final_loss))
print('Saving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)