Share plotting code between plot.py and train.py
parent 34cb3b863b · commit 3638749998
@@ -37,7 +37,27 @@ def messages_to_onehot(messages, order):
     return torch.nn.functional.one_hot(messages, num_classes=order).float()
 
 
-def plot_constellation(ax, constellation):
+def plot_constellation(
+        ax,
+        constellation,
+        channel,
+        decoder,
+        grid_step=0.05,
+        noise_samples=1000
+):
+    """
+    Plot a constellation with its decoder and channel noise.
+
+    :param ax: Matplotlib axes to plot on.
+    :param constellation: Constellation to plot.
+    :param channel: Channel model to use for generating noise.
+    :param decoder: Decoder function able to map the constellation points back
+    to the original messages.
+    :param grid_step: Grid step used for drawing the decision regions,
+    expressed as percentage of the total plot width (or equivalently height).
+    Lower steps makes more precise grids but takes more time to compute.
+    :param noise_samples: Number of noisy points to sample and plot.
+    """
     ax.grid()
 
     order = len(constellation)
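The new signature makes the helper self-contained: it takes the axes, the constellation points, a channel model and a decoder, and owns the decision-region and noise plotting that the scripts used to do inline. A minimal usage sketch, assuming a trained ConstellationNet checkpoint; the 'model.pth' path and the layer sizes are placeholders that mirror the plot.py hunk further down:

import constellation
from constellation import util

import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero

order = 4

# Rebuild the network with the same architecture it was trained with
# (placeholder sizes), then restore its weights.
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)
model.load_state_dict(torch.load('model.pth'))  # placeholder path
model.eval()

fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# One call now draws the constellation, its decision regions and a noise cloud.
util.plot_constellation(
    ax, model.get_constellation(),
    model.channel, model.decoder,
    grid_step=0.05, noise_samples=1000
)
pyplot.show()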
@@ -76,19 +96,54 @@ def plot_constellation(ax, constellation):
         va='center', ha='center'
     )
 
+    # Plot decision regions
+    regions_extent = 2 * axis_extent
+    step = grid_step * regions_extent
+    grid_range = torch.arange(-regions_extent, regions_extent, step)
+    grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
+    grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
+
+    ax.imshow(
+        grid_images,
+        extent=(
+            -regions_extent, regions_extent,
+            -regions_extent, regions_extent
+        ),
+        aspect='auto',
+        origin='lower',
+        cmap=color_map,
+        norm=color_norm,
+        alpha=0.2
+    )
+
+    # Plot constellation
     ax.scatter(
         *zip(*constellation.tolist()),
         zorder=10,
         s=60,
-        c=range(len(constellation)),
+        c=range(order),
         edgecolor='black',
         cmap=color_map,
         norm=color_norm,
     )
 
-    # Plot center
+    # Plot constellation center
     center = constellation.sum(dim=0) / order
     ax.scatter(
         center[0], center[1],
         marker='X',
     )
+
+    # Plot channel noise
+    noisy_vectors = channel(constellation.repeat(noise_samples, 1))
+
+    ax.scatter(
+        *zip(*noisy_vectors.tolist()),
+        marker='.',
+        s=5,
+        c=list(range(order)) * noise_samples,
+        cmap=color_map,
+        norm=color_norm,
+        alpha=0.7,
+        zorder=8
+    )
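The decision regions added above are computed by brute force: the decoder is evaluated on a square grid spanning twice the axis extent, each cell is coloured by the argmax of the decoder output, and the result is drawn with imshow under the constellation. The grid has 2 / grid_step points per axis, so the cost grows quadratically: 40 x 40 evaluations at the default grid_step=0.05, but 2000 x 2000 at the 0.001 used by plot.py, which is what the docstring's warning about compute time refers to. A self-contained sketch of the same pipeline, with a nearest-point stand-in where the project uses its learned decoder network (an assumption for illustration only):

import torch
from matplotlib import pyplot

# Stand-in decoder: scores each message by negative distance to its
# constellation point, so argmax picks the nearest point. The real project
# uses a learned decoder network instead.
def nearest_point_decoder(points, constellation):
    flat = points.reshape(-1, 2)
    scores = -torch.cdist(flat, constellation)
    return scores.reshape(*points.shape[:-1], len(constellation))

constellation = torch.tensor([[1., 1.], [-1., 1.], [-1., -1.], [1., -1.]])
axis_extent = 1.5
grid_step = 0.05

# Same grid construction as plot_constellation.
regions_extent = 2 * axis_extent
step = grid_step * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid = torch.stack((grid_x, grid_y), dim=2)  # shape (40, 40, 2)
grid_images = nearest_point_decoder(grid, constellation).argmax(dim=2)

fig, ax = pyplot.subplots()
ax.imshow(
    grid_images,
    extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
    origin='lower',
    alpha=0.2
)
ax.scatter(*zip(*constellation.tolist()))
pyplot.show()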
plot.py (95 changed lines)
@@ -17,105 +17,24 @@ color_map = matplotlib.cm.Dark2
 # Restore model from file
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(4,),
-    decoder_layers_sizes=(4,),
+    encoder_layers_sizes=(8,),
+    decoder_layers_sizes=(8,),
     channel_model=constellation.GaussianChannel()
 )
 
 model.load_state_dict(torch.load(input_file))
 model.eval()
 
-# Extract encoding
-with torch.no_grad():
-    encoded_vectors = model.encoder(
-        util.messages_to_onehot(
-            torch.arange(0, order),
-            order
-        )
-    )
-
 # Setup plot
 fig = pyplot.figure()
 ax = SubplotZero(fig, 111)
 fig.add_subplot(ax)
 
-# Extend axes symmetrically around zero so that they fit data
-axis_extent = max(
-    abs(encoded_vectors.min()),
-    abs(encoded_vectors.max())
-) * 1.05
-ax.set_xlim(-axis_extent, axis_extent)
-ax.set_ylim(-axis_extent, axis_extent)
-
-# Hide borders but keep ticks
-for direction in ['left', 'bottom', 'right', 'top']:
-    ax.axis[direction].line.set_color('#00000000')
-
-# Show zero-centered axes without ticks
-for direction in ['xzero', 'yzero']:
-    axis = ax.axis[direction]
-    axis.set_visible(True)
-    axis.set_axisline_style('-|>')
-    axis.major_ticklabels.set_visible(False)
-
-# Add axis names
-ax.annotate(
-    'I', (1, 0.5), xycoords='axes fraction',
-    xytext=(25, 0), textcoords='offset points',
-    va='center', ha='right'
-)
-
-ax.annotate(
-    'Q', (0.5, 1), xycoords='axes fraction',
-    xytext=(0, 25), textcoords='offset points',
-    va='center', ha='center'
-)
-
-ax.grid()
-
-# Plot decision regions
-color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
-
-regions_extent = 2 * axis_extent
-step = 0.001 * regions_extent
-grid_range = torch.arange(-regions_extent, regions_extent, step)
-grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
-grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
-
-ax.imshow(
-    grid_images,
-    extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
-    aspect='auto',
-    origin='lower',
-    cmap=color_map,
-    norm=color_norm,
-    alpha=0.1
-)
-
-# Plot encoded vectors
-ax.scatter(
-    *zip(*encoded_vectors.tolist()),
-    zorder=10,
-    s=60,
-    c=range(len(encoded_vectors)),
-    edgecolor='black',
-    cmap=color_map,
-    norm=color_norm,
-)
-
-# Plot noise
-noisy_count = 1000
-noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))
-
-ax.scatter(
-    *zip(*noisy_vectors.tolist()),
-    marker='.',
-    s=5,
-    c=list(range(len(encoded_vectors))) * noisy_count,
-    cmap=color_map,
-    norm=color_norm,
-    alpha=0.7,
-    zorder=8
+constellation = model.get_constellation()
+util.plot_constellation(
+    ax, constellation,
+    model.channel, model.decoder,
+    grid_step=0.001, noise_samples=2500
 )
 
 pyplot.show()
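A detail that carries over from the deleted plot.py code into the shared helper: channel(constellation.repeat(noise_samples, 1)) tiles the whole constellation noise_samples times row-wise, so the colour list list(range(order)) * noise_samples lines up with the repeated rows and every noise cloud keeps the colour of its constellation point. A short sketch of just that pattern, with an additive white Gaussian stand-in for constellation.GaussianChannel() (an assumed simplification):

import torch

order = 4
noise_samples = 3

# Toy constellation: four points on the unit square (illustrative values only).
constellation = torch.tensor([[1., 1.], [-1., 1.], [-1., -1.], [1., -1.]])

# Stand-in for constellation.GaussianChannel(): add white Gaussian noise.
def gaussian_channel(x, sigma=0.1):
    return x + sigma * torch.randn_like(x)

# repeat(noise_samples, 1) tiles the whole constellation noise_samples times,
# so row i of the result corresponds to message i % order.
noisy_vectors = gaussian_channel(constellation.repeat(noise_samples, 1))
colors = list(range(order)) * noise_samples  # one colour index per noisy row

print(noisy_vectors.shape)  # torch.Size([12, 2])
print(colors)               # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]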
train.py (16 changed lines)
@@ -2,9 +2,7 @@ import constellation
 from constellation import util
 import torch
 from matplotlib import pyplot
-import matplotlib
 from mpl_toolkits.axisartist.axislines import SubplotZero
-import time
 
 # Number of symbols to learn
 order = 4
@@ -95,16 +93,18 @@ while True:
     running_loss += loss.item()
 
     if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
-        ax.clear()
-        util.plot_constellation(ax, constellation)
-        fig.canvas.draw()
-        pyplot.pause(1e-17)
-        time.sleep(0.1)
-
         print('Batch #{} (size {})'.format(batch + 1, batch_size))
         print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
         print('\tChange is {}\n'.format(total_change))
 
+        ax.clear()
+        util.plot_constellation(
+            ax, constellation,
+            model.channel, model.decoder
+        )
+        fig.canvas.draw()
+        pyplot.pause(1e-17)
+
         running_loss = 0
 
     batch += 1
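In train.py the redraw still happens once per loss report, now through the shared helper: fig.canvas.draw() renders the updated axes and pyplot.pause(1e-17) lets Matplotlib's event loop show the new frame without blocking training, and the extra time.sleep(0.1) throttle is dropped along with its import. A minimal sketch of that live-update pattern on its own, where a shrinking placeholder curve stands in for the real training state:

from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero

fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Placeholder data: redraw a shrinking "loss" curve a few times.
for step in range(10):
    ax.clear()
    ax.grid()
    losses = [(1 - step / 10) / (i + 1) for i in range(20)]
    ax.plot(losses)

    # Same redraw pattern as train.py: draw, then let the GUI process events.
    fig.canvas.draw()
    pyplot.pause(0.2)

pyplot.show()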