|
|
@ -43,10 +43,6 @@ batch = 0 |
|
|
|
# Accumulated loss for last batches |
|
|
|
running_loss = 0 |
|
|
|
|
|
|
|
# True in the first training phase where small batches are used, and false in |
|
|
|
# the second phase where point positions are refined using large batches |
|
|
|
is_coarse_optim = True |
|
|
|
|
|
|
|
# List of training examples (not shuffled) |
|
|
|
classes_ordered = torch.arange(order).repeat(batch_size) |
|
|
|
|
|
|
@ -84,7 +80,7 @@ while total_change >= 1e-4: |
|
|
|
# Check for convergence |
|
|
|
model.eval() |
|
|
|
constellation = model.get_constellation() |
|
|
|
total_change = (constellation - prev_constellation).abs().sum() |
|
|
|
total_change = (constellation - prev_constellation).norm(dim=1).sum() |
|
|
|
prev_constellation = constellation |
|
|
|
|
|
|
|
# Report loss |
|
|
|