diff --git a/train.py b/train.py index 11e5ac7..9e97551 100644 --- a/train.py +++ b/train.py @@ -109,27 +109,6 @@ while total_change >= 1e-4: batch += 1 -print('\nFinished training\n') - -# Print some examples of reconstruction -model.eval() -print('Reconstruction examples:') -print('Input vector\t\t\tOutput vector after softmax') - -with torch.no_grad(): - onehot_example = util.messages_to_onehot(torch.arange(0, order), order) - raw_output = model(onehot_example) - raw_output.required_grad = False - reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1) - - for index in range(order): - print('{}\t\t{}'.format( - onehot_example[index].tolist(), - '[{}]'.format(', '.join( - '{:.5f}'.format(x) - for x in reconstructed_example[index].tolist() - )) - )) - -print('\nSaving model as {}'.format(output_file)) +print('\nFinished training') +print('Saving model as {}'.format(output_file)) torch.save(model.state_dict(), output_file)