Remove reconstruction examples after training
This commit is contained in:
parent
4bf0c0f363
commit
2c4786bdba
25
train.py
25
train.py
|
@ -109,27 +109,6 @@ while total_change >= 1e-4:
|
||||||
|
|
||||||
batch += 1
|
batch += 1
|
||||||
|
|
||||||
print('\nFinished training\n')
|
print('\nFinished training')
|
||||||
|
print('Saving model as {}'.format(output_file))
|
||||||
# Print some examples of reconstruction
|
|
||||||
model.eval()
|
|
||||||
print('Reconstruction examples:')
|
|
||||||
print('Input vector\t\t\tOutput vector after softmax')
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
|
|
||||||
raw_output = model(onehot_example)
|
|
||||||
raw_output.required_grad = False
|
|
||||||
reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)
|
|
||||||
|
|
||||||
for index in range(order):
|
|
||||||
print('{}\t\t{}'.format(
|
|
||||||
onehot_example[index].tolist(),
|
|
||||||
'[{}]'.format(', '.join(
|
|
||||||
'{:.5f}'.format(x)
|
|
||||||
for x in reconstructed_example[index].tolist()
|
|
||||||
))
|
|
||||||
))
|
|
||||||
|
|
||||||
print('\nSaving model as {}'.format(output_file))
|
|
||||||
torch.save(model.state_dict(), output_file)
|
torch.save(model.state_dict(), output_file)
|
||||||
|
|
Loading…
Reference in New Issue