Show final loss after training
This commit is contained in:
parent
1d39184036
commit
fb2518b321
12
train.py
12
train.py
|
@ -105,6 +105,18 @@ while total_change >= 1e-4:
|
|||
|
||||
batch += 1
|
||||
|
||||
model.eval()
|
||||
|
||||
# Calcul de la perte finale
|
||||
with torch.no_grad():
|
||||
classes_ordered = torch.arange(order).repeat(2048)
|
||||
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
||||
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
||||
|
||||
predictions = model(onehot_dataset)
|
||||
final_loss = criterion(predictions, classes_dataset)
|
||||
|
||||
print('\nFinished training')
|
||||
print('Final loss is {}'.format(final_loss))
|
||||
print('Saving model as {}'.format(output_file))
|
||||
torch.save(model.state_dict(), output_file)
|
||||
|
|
Loading…
Reference in New Issue