Fix plot scaling

This commit is contained in:
Mattéo Delabre 2019-12-15 00:03:02 -05:00
parent 8f6363ee21
commit 197a01e993
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 25 additions and 21 deletions

46
plot.py
View File

@ -14,13 +14,36 @@ input_file = 'output/constellation-order-{}.pth'.format(order)
# Color map used for decision regions # Color map used for decision regions
color_map = matplotlib.cm.Set1 color_map = matplotlib.cm.Set1
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model.eval()
# Extract encoding
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
# Setup plot # Setup plot
fig = pyplot.figure() fig = pyplot.figure()
ax = SubplotZero(fig, 111) ax = SubplotZero(fig, 111)
fig.add_subplot(ax) fig.add_subplot(ax)
# Extend axes symmetrically around zero so that they fit data # Extend axes symmetrically around zero so that they fit data
extent = 1.5 extent = max(
abs(encoded_vectors.min()),
abs(encoded_vectors.max())
) * 1.05
ax.set_xlim(-extent, extent) ax.set_xlim(-extent, extent)
ax.set_ylim(-extent, extent) ax.set_ylim(-extent, extent)
@ -68,21 +91,10 @@ ax.annotate(
ax.grid() ax.grid()
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model.eval()
# Plot decision regions # Plot decision regions
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
step = 0.01 step = 0.001 * extent
grid_range = torch.arange(-extent, extent, step) grid_range = torch.arange(-extent, extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range) grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
@ -97,14 +109,6 @@ ax.imshow(
) )
# Plot encoded vectors # Plot encoded vectors
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
ax.scatter( ax.scatter(
*zip(*encoded_vectors.tolist()), *zip(*encoded_vectors.tolist()),
zorder=10, zorder=10,