|
|
@ -109,27 +109,6 @@ while total_change >= 1e-4: |
|
|
|
|
|
|
|
batch += 1 |
|
|
|
|
|
|
|
print('\nFinished training\n') |
|
|
|
|
|
|
|
# Print some examples of reconstruction |
|
|
|
model.eval() |
|
|
|
print('Reconstruction examples:') |
|
|
|
print('Input vector\t\t\tOutput vector after softmax') |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
onehot_example = util.messages_to_onehot(torch.arange(0, order), order) |
|
|
|
raw_output = model(onehot_example) |
|
|
|
raw_output.required_grad = False |
|
|
|
reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1) |
|
|
|
|
|
|
|
for index in range(order): |
|
|
|
print('{}\t\t{}'.format( |
|
|
|
onehot_example[index].tolist(), |
|
|
|
'[{}]'.format(', '.join( |
|
|
|
'{:.5f}'.format(x) |
|
|
|
for x in reconstructed_example[index].tolist() |
|
|
|
)) |
|
|
|
)) |
|
|
|
|
|
|
|
print('\nSaving model as {}'.format(output_file)) |
|
|
|
print('\nFinished training') |
|
|
|
print('Saving model as {}'.format(output_file)) |
|
|
|
torch.save(model.state_dict(), output_file) |
|
|
|