diff --git a/train.py b/train.py index 64b6f7d..e4d0948 100644 --- a/train.py +++ b/train.py @@ -61,7 +61,7 @@ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( threshold=1e-8 ) -while total_change >= 1e-4: +while total_change >= 1e-3: # Shuffle training data and convert to one-hot encoding classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] onehot_dataset = util.messages_to_onehot(classes_dataset, order)