Balance training examples
This commit is contained in:
		
							parent
							
								
									99c96162c0
								
							
						
					
					
						commit
						6ea0e653c1
					
				
							
								
								
									
										7
									
								
								train.py
								
								
								
								
							
							
						
						
									
										7
									
								
								train.py
								
								
								
								
							|  | @ -6,10 +6,10 @@ import torch | |||
| order = 4 | ||||
| 
 | ||||
| # Number of training examples in an epoch | ||||
| epoch_size = 10000 | ||||
| epoch_size_multiple = 8 | ||||
| 
 | ||||
| # Number of epochs | ||||
| num_epochs = 20000 | ||||
| num_epochs = 5000 | ||||
| 
 | ||||
| # Number of epochs to skip between every loss report | ||||
| loss_report_epoch_skip = 500 | ||||
|  | @ -32,9 +32,10 @@ criterion = torch.nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.Adam(model.parameters()) | ||||
| 
 | ||||
| running_loss = 0 | ||||
| classes_ordered = torch.arange(order).repeat(epoch_size_multiple) | ||||
| 
 | ||||
| for epoch in range(num_epochs): | ||||
|     classes_dataset = util.get_random_messages(epoch_size, order) | ||||
|     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] | ||||
|     onehot_dataset = util.messages_to_onehot(classes_dataset, order) | ||||
| 
 | ||||
|     optimizer.zero_grad() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue