Plot decision regions from the decoder
This commit is contained in:
parent
3b40e27070
commit
8f6363ee21
77
plot.py
77
plot.py
|
@ -2,6 +2,7 @@ import constellation
|
|||
from constellation import util
|
||||
import torch
|
||||
from matplotlib import pyplot
|
||||
import matplotlib
|
||||
from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||
|
||||
# Number learned symbols
|
||||
|
@ -10,36 +11,16 @@ order = 4
|
|||
# File in which the trained model is saved
|
||||
input_file = 'output/constellation-order-{}.pth'.format(order)
|
||||
|
||||
# Restore model from file
|
||||
model = constellation.ConstellationNet(
|
||||
order=order,
|
||||
encoder_layers_sizes=(4,),
|
||||
decoder_layers_sizes=(4,),
|
||||
channel_model=constellation.GaussianChannel()
|
||||
)
|
||||
|
||||
model.load_state_dict(torch.load(input_file))
|
||||
model.eval()
|
||||
|
||||
# Compute encoded vectors
|
||||
with torch.no_grad():
|
||||
encoded_vectors = model.encoder(
|
||||
util.messages_to_onehot(
|
||||
torch.arange(0, order),
|
||||
order
|
||||
)
|
||||
)
|
||||
# Color map used for decision regions
|
||||
color_map = matplotlib.cm.Set1
|
||||
|
||||
# Setup plot
|
||||
fig = pyplot.figure()
|
||||
ax = SubplotZero(fig, 111)
|
||||
fig.add_subplot(ax)
|
||||
|
||||
# Extend axes symmetrically around zero so that they fit data
|
||||
extent = max(
|
||||
abs(encoded_vectors.min()),
|
||||
abs(encoded_vectors.max())
|
||||
) * 1.05
|
||||
|
||||
extent = 1.5
|
||||
ax.set_xlim(-extent, extent)
|
||||
ax.set_ylim(-extent, extent)
|
||||
|
||||
|
@ -87,10 +68,52 @@ ax.annotate(
|
|||
|
||||
ax.grid()
|
||||
|
||||
# Plot encoded vectors
|
||||
ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10)
|
||||
# Restore model from file
|
||||
model = constellation.ConstellationNet(
|
||||
order=order,
|
||||
encoder_layers_sizes=(4,),
|
||||
decoder_layers_sizes=(4,),
|
||||
channel_model=constellation.GaussianChannel()
|
||||
)
|
||||
|
||||
model.load_state_dict(torch.load(input_file))
|
||||
model.eval()
|
||||
|
||||
# Plot decision regions
|
||||
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
|
||||
|
||||
step = 0.01
|
||||
grid_range = torch.arange(-extent, extent, step)
|
||||
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
||||
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
||||
|
||||
ax.imshow(
|
||||
grid_images, extent=(-extent, extent, -extent, extent),
|
||||
aspect="auto",
|
||||
origin="lower",
|
||||
cmap=color_map,
|
||||
norm=color_norm,
|
||||
alpha=0.15
|
||||
)
|
||||
|
||||
# Plot encoded vectors
|
||||
with torch.no_grad():
|
||||
encoded_vectors = model.encoder(
|
||||
util.messages_to_onehot(
|
||||
torch.arange(0, order),
|
||||
order
|
||||
)
|
||||
)
|
||||
|
||||
ax.scatter(
|
||||
*zip(*encoded_vectors.tolist()),
|
||||
zorder=10,
|
||||
c=range(len(encoded_vectors)),
|
||||
edgecolor='black',
|
||||
cmap=color_map,
|
||||
norm=color_norm,
|
||||
)
|
||||
|
||||
# Add index label for each vector
|
||||
for row in range(order):
|
||||
ax.annotate(
|
||||
row + 1, encoded_vectors[row],
|
||||
|
|
Loading…
Reference in New Issue