Share plotting code between plot.py and train.py

This commit is contained in:
Mattéo Delabre 2019-12-15 20:05:20 -05:00
parent 34cb3b863b
commit 3638749998
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
3 changed files with 73 additions and 99 deletions

View File

@ -37,7 +37,27 @@ def messages_to_onehot(messages, order):
return torch.nn.functional.one_hot(messages, num_classes=order).float() return torch.nn.functional.one_hot(messages, num_classes=order).float()
def plot_constellation(ax, constellation): def plot_constellation(
ax,
constellation,
channel,
decoder,
grid_step=0.05,
noise_samples=1000
):
"""
Plot a constellation with its decoder and channel noise.
:param ax: Matplotlib axes to plot on.
:param constellation: Constellation to plot.
:param channel: Channel model to use for generating noise.
:param decoder: Decoder function able to map the constellation points back
to the original messages.
:param grid_step: Grid step used for drawing the decision regions,
expressed as percentage of the total plot width (or equivalently height).
Lower steps makes more precise grids but takes more time to compute.
:param noise_samples: Number of noisy points to sample and plot.
"""
ax.grid() ax.grid()
order = len(constellation) order = len(constellation)
@ -76,19 +96,54 @@ def plot_constellation(ax, constellation):
va='center', ha='center' va='center', ha='center'
) )
# Plot decision regions
regions_extent = 2 * axis_extent
step = grid_step * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
ax.imshow(
grid_images,
extent=(
-regions_extent, regions_extent,
-regions_extent, regions_extent
),
aspect='auto',
origin='lower',
cmap=color_map,
norm=color_norm,
alpha=0.2
)
# Plot constellation
ax.scatter( ax.scatter(
*zip(*constellation.tolist()), *zip(*constellation.tolist()),
zorder=10, zorder=10,
s=60, s=60,
c=range(len(constellation)), c=range(order),
edgecolor='black', edgecolor='black',
cmap=color_map, cmap=color_map,
norm=color_norm, norm=color_norm,
) )
# Plot center # Plot constellation center
center = constellation.sum(dim=0) / order center = constellation.sum(dim=0) / order
ax.scatter( ax.scatter(
center[0], center[1], center[0], center[1],
marker='X', marker='X',
) )
# Plot channel noise
noisy_vectors = channel(constellation.repeat(noise_samples, 1))
ax.scatter(
*zip(*noisy_vectors.tolist()),
marker='.',
s=5,
c=list(range(order)) * noise_samples,
cmap=color_map,
norm=color_norm,
alpha=0.7,
zorder=8
)

95
plot.py
View File

@ -17,105 +17,24 @@ color_map = matplotlib.cm.Dark2
# Restore model from file # Restore model from file
model = constellation.ConstellationNet( model = constellation.ConstellationNet(
order=order, order=order,
encoder_layers_sizes=(4,), encoder_layers_sizes=(8,),
decoder_layers_sizes=(4,), decoder_layers_sizes=(8,),
channel_model=constellation.GaussianChannel() channel_model=constellation.GaussianChannel()
) )
model.load_state_dict(torch.load(input_file)) model.load_state_dict(torch.load(input_file))
model.eval() model.eval()
# Extract encoding
with torch.no_grad():
encoded_vectors = model.encoder(
util.messages_to_onehot(
torch.arange(0, order),
order
)
)
# Setup plot # Setup plot
fig = pyplot.figure() fig = pyplot.figure()
ax = SubplotZero(fig, 111) ax = SubplotZero(fig, 111)
fig.add_subplot(ax) fig.add_subplot(ax)
# Extend axes symmetrically around zero so that they fit data constellation = model.get_constellation()
axis_extent = max( util.plot_constellation(
abs(encoded_vectors.min()), ax, constellation,
abs(encoded_vectors.max()) model.channel, model.decoder,
) * 1.05 grid_step=0.001, noise_samples=2500
ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent)
# Hide borders but keep ticks
for direction in ['left', 'bottom', 'right', 'top']:
ax.axis[direction].line.set_color('#00000000')
# Show zero-centered axes without ticks
for direction in ['xzero', 'yzero']:
axis = ax.axis[direction]
axis.set_visible(True)
axis.set_axisline_style('-|>')
axis.major_ticklabels.set_visible(False)
# Add axis names
ax.annotate(
'I', (1, 0.5), xycoords='axes fraction',
xytext=(25, 0), textcoords='offset points',
va='center', ha='right'
)
ax.annotate(
'Q', (0.5, 1), xycoords='axes fraction',
xytext=(0, 25), textcoords='offset points',
va='center', ha='center'
)
ax.grid()
# Plot decision regions
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
regions_extent = 2 * axis_extent
step = 0.001 * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
ax.imshow(
grid_images,
extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
aspect='auto',
origin='lower',
cmap=color_map,
norm=color_norm,
alpha=0.1
)
# Plot encoded vectors
ax.scatter(
*zip(*encoded_vectors.tolist()),
zorder=10,
s=60,
c=range(len(encoded_vectors)),
edgecolor='black',
cmap=color_map,
norm=color_norm,
)
# Plot noise
noisy_count = 1000
noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))
ax.scatter(
*zip(*noisy_vectors.tolist()),
marker='.',
s=5,
c=list(range(len(encoded_vectors))) * noisy_count,
cmap=color_map,
norm=color_norm,
alpha=0.7,
zorder=8
) )
pyplot.show() pyplot.show()

View File

@ -2,9 +2,7 @@ import constellation
from constellation import util from constellation import util
import torch import torch
from matplotlib import pyplot from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero from mpl_toolkits.axisartist.axislines import SubplotZero
import time
# Number of symbols to learn # Number of symbols to learn
order = 4 order = 4
@ -95,16 +93,18 @@ while True:
running_loss += loss.item() running_loss += loss.item()
if batch % loss_report_batch_skip == loss_report_batch_skip - 1: if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
ax.clear()
util.plot_constellation(ax, constellation)
fig.canvas.draw()
pyplot.pause(1e-17)
time.sleep(0.1)
print('Batch #{} (size {})'.format(batch + 1, batch_size)) print('Batch #{} (size {})'.format(batch + 1, batch_size))
print('\tLoss is {}'.format(running_loss / loss_report_batch_skip)) print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
print('\tChange is {}\n'.format(total_change)) print('\tChange is {}\n'.format(total_change))
ax.clear()
util.plot_constellation(
ax, constellation,
model.channel, model.decoder
)
fig.canvas.draw()
pyplot.pause(1e-17)
running_loss = 0 running_loss = 0
batch += 1 batch += 1