diff --git a/train.py b/train.py
index 9e97551..d2a4b38 100644
--- a/train.py
+++ b/train.py
@@ -43,10 +43,6 @@ batch = 0
 # Accumulated loss for last batches
 running_loss = 0
 
-# True in the first training phase where small batches are used, and false in
-# the second phase where point positions are refined using large batches
-is_coarse_optim = True
-
 # List of training examples (not shuffled)
 classes_ordered = torch.arange(order).repeat(batch_size)
 
@@ -84,7 +80,7 @@ while total_change >= 1e-4:
     # Check for convergence
     model.eval()
     constellation = model.get_constellation()
-    total_change = (constellation - prev_constellation).abs().sum()
+    total_change = (constellation - prev_constellation).norm(dim=1).sum()
     prev_constellation = constellation
 
     # Report loss
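
For reference, a minimal sketch of what the changed convergence metric computes, assuming (as the code above suggests, but not confirmed here) that the constellation is a 2-D tensor with one point per row: the old expression summed absolute coordinate differences, while the new one sums the Euclidean distance each constellation point has moved since the previous check. The tensors below are made-up examples, not values from train.py.

import torch

# Hypothetical constellations before and after an update, shape [order, 2]
prev_constellation = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
constellation = torch.tensor([[0.3, 0.4], [1.0, 1.0]])

delta = constellation - prev_constellation

# Old metric: sum of absolute coordinate differences (elementwise L1)
l1_change = delta.abs().sum()        # 0.3 + 0.4 = 0.7

# New metric: per-point Euclidean distance moved, summed over points
l2_change = delta.norm(dim=1).sum()  # sqrt(0.3**2 + 0.4**2) = 0.5

print(l1_change.item(), l2_change.item())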