Fix variable naming clash
This commit is contained in:
parent
fb2518b321
commit
af59421dc0
10
train.py
10
train.py
|
@ -47,7 +47,7 @@ running_loss = 0
|
|||
classes_ordered = torch.arange(order).repeat(batch_size)
|
||||
|
||||
# Constellation from the previous training batch
|
||||
prev_constellation = model.get_constellation()
|
||||
prev_constel = model.get_constellation()
|
||||
total_change = float('inf')
|
||||
|
||||
# Optimizer settings
|
||||
|
@ -79,9 +79,9 @@ while total_change >= 1e-4:
|
|||
|
||||
# Check for convergence
|
||||
model.eval()
|
||||
constellation = model.get_constellation()
|
||||
total_change = (constellation - prev_constellation).norm(dim=1).sum()
|
||||
prev_constellation = constellation
|
||||
cur_constel = model.get_constellation()
|
||||
total_change = (cur_constel - prev_constel).norm(dim=1).sum()
|
||||
prev_constel = cur_constel
|
||||
|
||||
# Report loss
|
||||
running_loss += loss.item()
|
||||
|
@ -96,7 +96,7 @@ while total_change >= 1e-4:
|
|||
# Update figure with current encoding
|
||||
ax.clear()
|
||||
util.plot_constellation(
|
||||
ax, constellation,
|
||||
ax, cur_constel,
|
||||
model.channel, model.decoder,
|
||||
noise_samples=0
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue