From af59421dc043c785db122b30981bf4489c5f120b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Mon, 16 Dec 2019 10:31:16 -0500 Subject: [PATCH] Fix variable naming clash --- train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 011d3e4..64b6f7d 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,7 @@ running_loss = 0 classes_ordered = torch.arange(order).repeat(batch_size) # Constellation from the previous training batch -prev_constellation = model.get_constellation() +prev_constel = model.get_constellation() total_change = float('inf') # Optimizer settings @@ -79,9 +79,9 @@ while total_change >= 1e-4: # Check for convergence model.eval() - constellation = model.get_constellation() - total_change = (constellation - prev_constellation).norm(dim=1).sum() - prev_constellation = constellation + cur_constel = model.get_constellation() + total_change = (cur_constel - prev_constel).norm(dim=1).sum() + prev_constel = cur_constel # Report loss running_loss += loss.item() @@ -96,7 +96,7 @@ while total_change >= 1e-4: # Update figure with current encoding ax.clear() util.plot_constellation( - ax, constellation, + ax, cur_constel, model.channel, model.decoder, noise_samples=0 )