|
|
@ -47,7 +47,7 @@ running_loss = 0 |
|
|
|
classes_ordered = torch.arange(order).repeat(batch_size) |
|
|
|
|
|
|
|
# Constellation from the previous training batch |
|
|
|
prev_constellation = model.get_constellation() |
|
|
|
prev_constel = model.get_constellation() |
|
|
|
total_change = float('inf') |
|
|
|
|
|
|
|
# Optimizer settings |
|
|
@ -79,9 +79,9 @@ while total_change >= 1e-4: |
|
|
|
|
|
|
|
# Check for convergence |
|
|
|
model.eval() |
|
|
|
constellation = model.get_constellation() |
|
|
|
total_change = (constellation - prev_constellation).norm(dim=1).sum() |
|
|
|
prev_constellation = constellation |
|
|
|
cur_constel = model.get_constellation() |
|
|
|
total_change = (cur_constel - prev_constel).norm(dim=1).sum() |
|
|
|
prev_constel = cur_constel |
|
|
|
|
|
|
|
# Report loss |
|
|
|
running_loss += loss.item() |
|
|
@ -96,7 +96,7 @@ while total_change >= 1e-4: |
|
|
|
# Update figure with current encoding |
|
|
|
ax.clear() |
|
|
|
util.plot_constellation( |
|
|
|
ax, constellation, |
|
|
|
ax, cur_constel, |
|
|
|
model.channel, model.decoder, |
|
|
|
noise_samples=0 |
|
|
|
) |
|
|
|