Fix variable naming clash
This commit is contained in:
parent
fb2518b321
commit
af59421dc0
10
train.py
10
train.py
|
@ -47,7 +47,7 @@ running_loss = 0
|
||||||
classes_ordered = torch.arange(order).repeat(batch_size)
|
classes_ordered = torch.arange(order).repeat(batch_size)
|
||||||
|
|
||||||
# Constellation from the previous training batch
|
# Constellation from the previous training batch
|
||||||
prev_constellation = model.get_constellation()
|
prev_constel = model.get_constellation()
|
||||||
total_change = float('inf')
|
total_change = float('inf')
|
||||||
|
|
||||||
# Optimizer settings
|
# Optimizer settings
|
||||||
|
@ -79,9 +79,9 @@ while total_change >= 1e-4:
|
||||||
|
|
||||||
# Check for convergence
|
# Check for convergence
|
||||||
model.eval()
|
model.eval()
|
||||||
constellation = model.get_constellation()
|
cur_constel = model.get_constellation()
|
||||||
total_change = (constellation - prev_constellation).norm(dim=1).sum()
|
total_change = (cur_constel - prev_constel).norm(dim=1).sum()
|
||||||
prev_constellation = constellation
|
prev_constel = cur_constel
|
||||||
|
|
||||||
# Report loss
|
# Report loss
|
||||||
running_loss += loss.item()
|
running_loss += loss.item()
|
||||||
|
@ -96,7 +96,7 @@ while total_change >= 1e-4:
|
||||||
# Update figure with current encoding
|
# Update figure with current encoding
|
||||||
ax.clear()
|
ax.clear()
|
||||||
util.plot_constellation(
|
util.plot_constellation(
|
||||||
ax, constellation,
|
ax, cur_constel,
|
||||||
model.channel, model.decoder,
|
model.channel, model.decoder,
|
||||||
noise_samples=0
|
noise_samples=0
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue