From 36387499981a11885411ba298fa76fde1a8008ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Sun, 15 Dec 2019 20:05:20 -0500 Subject: [PATCH] Share plotting code between plot.py and train.py --- constellation/util.py | 61 +++++++++++++++++++++++++-- plot.py | 95 ++++--------------------------------------- train.py | 16 ++++---- 3 files changed, 73 insertions(+), 99 deletions(-) diff --git a/constellation/util.py b/constellation/util.py index 4aabf06..a02c2c0 100644 --- a/constellation/util.py +++ b/constellation/util.py @@ -37,7 +37,27 @@ def messages_to_onehot(messages, order): return torch.nn.functional.one_hot(messages, num_classes=order).float() -def plot_constellation(ax, constellation): +def plot_constellation( + ax, + constellation, + channel, + decoder, + grid_step=0.05, + noise_samples=1000 +): + """ + Plot a constellation with its decoder and channel noise. + + :param ax: Matplotlib axes to plot on. + :param constellation: Constellation to plot. + :param channel: Channel model to use for generating noise. + :param decoder: Decoder function able to map the constellation points back + to the original messages. + :param grid_step: Grid step used for drawing the decision regions, + expressed as percentage of the total plot width (or equivalently height). + Lower steps makes more precise grids but takes more time to compute. + :param noise_samples: Number of noisy points to sample and plot. + """ ax.grid() order = len(constellation) @@ -76,19 +96,54 @@ def plot_constellation(ax, constellation): va='center', ha='center' ) + # Plot decision regions + regions_extent = 2 * axis_extent + step = grid_step * regions_extent + grid_range = torch.arange(-regions_extent, regions_extent, step) + grid_y, grid_x = torch.meshgrid(grid_range, grid_range) + grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) + + ax.imshow( + grid_images, + extent=( + -regions_extent, regions_extent, + -regions_extent, regions_extent + ), + aspect='auto', + origin='lower', + cmap=color_map, + norm=color_norm, + alpha=0.2 + ) + + # Plot constellation ax.scatter( *zip(*constellation.tolist()), zorder=10, s=60, - c=range(len(constellation)), + c=range(order), edgecolor='black', cmap=color_map, norm=color_norm, ) - # Plot center + # Plot constellation center center = constellation.sum(dim=0) / order ax.scatter( center[0], center[1], marker='X', ) + + # Plot channel noise + noisy_vectors = channel(constellation.repeat(noise_samples, 1)) + + ax.scatter( + *zip(*noisy_vectors.tolist()), + marker='.', + s=5, + c=list(range(order)) * noise_samples, + cmap=color_map, + norm=color_norm, + alpha=0.7, + zorder=8 + ) diff --git a/plot.py b/plot.py index 08ad24b..93b74cd 100644 --- a/plot.py +++ b/plot.py @@ -17,105 +17,24 @@ color_map = matplotlib.cm.Dark2 # Restore model from file model = constellation.ConstellationNet( order=order, - encoder_layers_sizes=(4,), - decoder_layers_sizes=(4,), + encoder_layers_sizes=(8,), + decoder_layers_sizes=(8,), channel_model=constellation.GaussianChannel() ) model.load_state_dict(torch.load(input_file)) model.eval() -# Extract encoding -with torch.no_grad(): - encoded_vectors = model.encoder( - util.messages_to_onehot( - torch.arange(0, order), - order - ) - ) - # Setup plot fig = pyplot.figure() ax = SubplotZero(fig, 111) fig.add_subplot(ax) -# Extend axes symmetrically around zero so that they fit data -axis_extent = max( - abs(encoded_vectors.min()), - abs(encoded_vectors.max()) -) * 1.05 -ax.set_xlim(-axis_extent, axis_extent) -ax.set_ylim(-axis_extent, axis_extent) - -# Hide borders but keep ticks -for direction in ['left', 'bottom', 'right', 'top']: - ax.axis[direction].line.set_color('#00000000') - -# Show zero-centered axes without ticks -for direction in ['xzero', 'yzero']: - axis = ax.axis[direction] - axis.set_visible(True) - axis.set_axisline_style('-|>') - axis.major_ticklabels.set_visible(False) - -# Add axis names -ax.annotate( - 'I', (1, 0.5), xycoords='axes fraction', - xytext=(25, 0), textcoords='offset points', - va='center', ha='right' -) - -ax.annotate( - 'Q', (0.5, 1), xycoords='axes fraction', - xytext=(0, 25), textcoords='offset points', - va='center', ha='center' -) - -ax.grid() - -# Plot decision regions -color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) - -regions_extent = 2 * axis_extent -step = 0.001 * regions_extent -grid_range = torch.arange(-regions_extent, regions_extent, step) -grid_y, grid_x = torch.meshgrid(grid_range, grid_range) -grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) - -ax.imshow( - grid_images, - extent=(-regions_extent, regions_extent, -regions_extent, regions_extent), - aspect='auto', - origin='lower', - cmap=color_map, - norm=color_norm, - alpha=0.1 -) - -# Plot encoded vectors -ax.scatter( - *zip(*encoded_vectors.tolist()), - zorder=10, - s=60, - c=range(len(encoded_vectors)), - edgecolor='black', - cmap=color_map, - norm=color_norm, -) - -# Plot noise -noisy_count = 1000 -noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1)) - -ax.scatter( - *zip(*noisy_vectors.tolist()), - marker='.', - s=5, - c=list(range(len(encoded_vectors))) * noisy_count, - cmap=color_map, - norm=color_norm, - alpha=0.7, - zorder=8 +constellation = model.get_constellation() +util.plot_constellation( + ax, constellation, + model.channel, model.decoder, + grid_step=0.001, noise_samples=2500 ) pyplot.show() diff --git a/train.py b/train.py index 8dec099..f6d274e 100644 --- a/train.py +++ b/train.py @@ -2,9 +2,7 @@ import constellation from constellation import util import torch from matplotlib import pyplot -import matplotlib from mpl_toolkits.axisartist.axislines import SubplotZero -import time # Number of symbols to learn order = 4 @@ -95,16 +93,18 @@ while True: running_loss += loss.item() if batch % loss_report_batch_skip == loss_report_batch_skip - 1: - ax.clear() - util.plot_constellation(ax, constellation) - fig.canvas.draw() - pyplot.pause(1e-17) - time.sleep(0.1) - print('Batch #{} (size {})'.format(batch + 1, batch_size)) print('\tLoss is {}'.format(running_loss / loss_report_batch_skip)) print('\tChange is {}\n'.format(total_change)) + ax.clear() + util.plot_constellation( + ax, constellation, + model.channel, model.decoder + ) + fig.canvas.draw() + pyplot.pause(1e-17) + running_loss = 0 batch += 1