diff --git a/train.py b/train.py index 011d3e4..64b6f7d 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,7 @@ running_loss = 0 classes_ordered = torch.arange(order).repeat(batch_size) # Constellation from the previous training batch -prev_constellation = model.get_constellation() +prev_constel = model.get_constellation() total_change = float('inf') # Optimizer settings @@ -79,9 +79,9 @@ while total_change >= 1e-4: # Check for convergence model.eval() - constellation = model.get_constellation() - total_change = (constellation - prev_constellation).norm(dim=1).sum() - prev_constellation = constellation + cur_constel = model.get_constellation() + total_change = (cur_constel - prev_constel).norm(dim=1).sum() + prev_constel = cur_constel # Report loss running_loss += loss.item() @@ -96,7 +96,7 @@ while total_change >= 1e-4: # Update figure with current encoding ax.clear() util.plot_constellation( - ax, constellation, + ax, cur_constel, model.channel, model.decoder, noise_samples=0 )