Show final loss after training
This commit is contained in:
		
							parent
							
								
									1d39184036
								
							
						
					
					
						commit
						fb2518b321
					
				
							
								
								
									
										12
									
								
								train.py
								
								
								
								
							
							
						
						
									
										12
									
								
								train.py
								
								
								
								
							|  | @ -105,6 +105,18 @@ while total_change >= 1e-4: | ||||||
| 
 | 
 | ||||||
|     batch += 1 |     batch += 1 | ||||||
| 
 | 
 | ||||||
|  | model.eval() | ||||||
|  | 
 | ||||||
|  | # Calcul de la perte finale | ||||||
|  | with torch.no_grad(): | ||||||
|  |     classes_ordered = torch.arange(order).repeat(2048) | ||||||
|  |     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] | ||||||
|  |     onehot_dataset = util.messages_to_onehot(classes_dataset, order) | ||||||
|  | 
 | ||||||
|  |     predictions = model(onehot_dataset) | ||||||
|  |     final_loss = criterion(predictions, classes_dataset) | ||||||
|  | 
 | ||||||
| print('\nFinished training') | print('\nFinished training') | ||||||
|  | print('Final loss is {}'.format(final_loss)) | ||||||
| print('Saving model as {}'.format(output_file)) | print('Saving model as {}'.format(output_file)) | ||||||
| torch.save(model.state_dict(), output_file) | torch.save(model.state_dict(), output_file) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue