Show final loss after training
This commit is contained in:
parent
1d39184036
commit
fb2518b321
12
train.py
12
train.py
|
@ -105,6 +105,18 @@ while total_change >= 1e-4:
|
||||||
|
|
||||||
batch += 1
|
batch += 1
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Calcul de la perte finale
|
||||||
|
with torch.no_grad():
|
||||||
|
classes_ordered = torch.arange(order).repeat(2048)
|
||||||
|
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
||||||
|
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
||||||
|
|
||||||
|
predictions = model(onehot_dataset)
|
||||||
|
final_loss = criterion(predictions, classes_dataset)
|
||||||
|
|
||||||
print('\nFinished training')
|
print('\nFinished training')
|
||||||
|
print('Final loss is {}'.format(final_loss))
|
||||||
print('Saving model as {}'.format(output_file))
|
print('Saving model as {}'.format(output_file))
|
||||||
torch.save(model.state_dict(), output_file)
|
torch.save(model.state_dict(), output_file)
|
||||||
|
|
Loading…
Reference in New Issue