Browse Source

Fix variable naming clash

master
Mattéo Delabre 4 years ago
parent
commit
af59421dc0
Signed by: matteo GPG Key ID: AE3FBD02DC583ABB
  1. 10
      train.py

10
train.py

@ -47,7 +47,7 @@ running_loss = 0
classes_ordered = torch.arange(order).repeat(batch_size)
# Constellation from the previous training batch
prev_constellation = model.get_constellation()
prev_constel = model.get_constellation()
total_change = float('inf')
# Optimizer settings
@ -79,9 +79,9 @@ while total_change >= 1e-4:
# Check for convergence
model.eval()
constellation = model.get_constellation()
total_change = (constellation - prev_constellation).norm(dim=1).sum()
prev_constellation = constellation
cur_constel = model.get_constellation()
total_change = (cur_constel - prev_constel).norm(dim=1).sum()
prev_constel = cur_constel
# Report loss
running_loss += loss.item()
@ -96,7 +96,7 @@ while total_change >= 1e-4:
# Update figure with current encoding
ax.clear()
util.plot_constellation(
ax, constellation,
ax, cur_constel,
model.channel, model.decoder,
noise_samples=0
)

Loading…
Cancel
Save