Fix plot scaling

This commit is contained in:
Mattéo Delabre 2019-12-15 00:03:02 -05:00
parent 8f6363ee21
commit 197a01e993
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 25 additions and 21 deletions

46
plot.py
View File

@ -14,13 +14,36 @@ input_file = 'output/constellation-order-{}.pth'.format(order)
# Color map used for decision regions
color_map = matplotlib.cm.Set1
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model.eval()
# Extract encoding
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
# Setup plot
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
# Extend axes symmetrically around zero so that they fit data
extent = 1.5
extent = max(
abs(encoded_vectors.min()),
abs(encoded_vectors.max())
) * 1.05
ax.set_xlim(-extent, extent)
ax.set_ylim(-extent, extent)
@ -68,21 +91,10 @@ ax.annotate(
ax.grid()
# Restore model from file
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load(input_file))
model.eval()
# Plot decision regions
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
step = 0.01
step = 0.001 * extent
grid_range = torch.arange(-extent, extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
@ -97,14 +109,6 @@ ax.imshow(
)
# Plot encoded vectors
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
ax.scatter(
*zip(*encoded_vectors.tolist()),
zorder=10,