Fix plot scaling
This commit is contained in:
parent
8f6363ee21
commit
197a01e993
46
plot.py
46
plot.py
|
@ -14,13 +14,36 @@ input_file = 'output/constellation-order-{}.pth'.format(order)
|
||||||
# Color map used for decision regions
|
# Color map used for decision regions
|
||||||
color_map = matplotlib.cm.Set1
|
color_map = matplotlib.cm.Set1
|
||||||
|
|
||||||
|
# Restore model from file
|
||||||
|
model = constellation.ConstellationNet(
|
||||||
|
order=order,
|
||||||
|
encoder_layers_sizes=(4,),
|
||||||
|
decoder_layers_sizes=(4,),
|
||||||
|
channel_model=constellation.GaussianChannel()
|
||||||
|
)
|
||||||
|
|
||||||
|
model.load_state_dict(torch.load(input_file))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Extract encoding
|
||||||
|
with torch.no_grad():
|
||||||
|
encoded_vectors = model.encoder(
|
||||||
|
util.messages_to_onehot(
|
||||||
|
torch.arange(0, order),
|
||||||
|
order
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Setup plot
|
# Setup plot
|
||||||
fig = pyplot.figure()
|
fig = pyplot.figure()
|
||||||
ax = SubplotZero(fig, 111)
|
ax = SubplotZero(fig, 111)
|
||||||
fig.add_subplot(ax)
|
fig.add_subplot(ax)
|
||||||
|
|
||||||
# Extend axes symmetrically around zero so that they fit data
|
# Extend axes symmetrically around zero so that they fit data
|
||||||
extent = 1.5
|
extent = max(
|
||||||
|
abs(encoded_vectors.min()),
|
||||||
|
abs(encoded_vectors.max())
|
||||||
|
) * 1.05
|
||||||
ax.set_xlim(-extent, extent)
|
ax.set_xlim(-extent, extent)
|
||||||
ax.set_ylim(-extent, extent)
|
ax.set_ylim(-extent, extent)
|
||||||
|
|
||||||
|
@ -68,21 +91,10 @@ ax.annotate(
|
||||||
|
|
||||||
ax.grid()
|
ax.grid()
|
||||||
|
|
||||||
# Restore model from file
|
|
||||||
model = constellation.ConstellationNet(
|
|
||||||
order=order,
|
|
||||||
encoder_layers_sizes=(4,),
|
|
||||||
decoder_layers_sizes=(4,),
|
|
||||||
channel_model=constellation.GaussianChannel()
|
|
||||||
)
|
|
||||||
|
|
||||||
model.load_state_dict(torch.load(input_file))
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# Plot decision regions
|
# Plot decision regions
|
||||||
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
|
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
|
||||||
|
|
||||||
step = 0.01
|
step = 0.001 * extent
|
||||||
grid_range = torch.arange(-extent, extent, step)
|
grid_range = torch.arange(-extent, extent, step)
|
||||||
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
||||||
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
||||||
|
@ -97,14 +109,6 @@ ax.imshow(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot encoded vectors
|
# Plot encoded vectors
|
||||||
with torch.no_grad():
|
|
||||||
encoded_vectors = model.encoder(
|
|
||||||
util.messages_to_onehot(
|
|
||||||
torch.arange(0, order),
|
|
||||||
order
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
*zip(*encoded_vectors.tolist()),
|
*zip(*encoded_vectors.tolist()),
|
||||||
zorder=10,
|
zorder=10,
|
||||||
|
|
Loading…
Reference in New Issue