constellationnet/constellation/util.py

176 lines
4.9 KiB
Python

import itertools
import torch
import matplotlib
from matplotlib.colors import ListedColormap
import seaborn
def get_random_messages(count, order):
    """
    Generate a batch of uniformly random messages.

    :param count: Number of messages to generate.
    :param order: Number of possible messages.
    :return: One-dimensional tensor of length `count`. Each entry is the
        index of a generated message, drawn uniformly between 0, inclusive,
        and `order`, exclusive.

    The values are random, so only the shape is shown here:

    >>> get_random_messages(5, 4).shape
    torch.Size([5])
    """
    return torch.randint(0, order, (count,))
def messages_to_onehot(messages, order):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Integer tensor of message indexes to convert.
    :param order: Number of possible messages; this is the width of each
        one-hot row.
    :return: Float tensor of one-hot encoded messages, with one row per
        input message and `order` columns.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    # one_hot returns an integer tensor; convert to float for use in models.
    return torch.nn.functional.one_hot(messages, num_classes=order).float()
def plot_constellation(
    ax,
    constellation,
    channel,
    decoder,
    grid_step=0.05,
    noise_samples=1000
):
    """
    Plot a constellation with its decoder and channel noise.

    :param ax: Matplotlib axes to plot on. The axes must support the
        ``ax.axis[name]`` indexing used below, including the ``'xzero'`` and
        ``'yzero'`` axis lines. Axes from ``mpl_toolkits.axisartist``
        (e.g. ``SubplotZero``) provide this.
    :param constellation: Tensor of constellation points with one row per
        message. Column 0 is the in-phase (I, x) coordinate and column 1 is
        the quadrature (Q, y) coordinate.
    :param channel: Channel model to use for generating noise. It is called
        on a tensor of stacked copies of the constellation points and must
        return one noisy point per input row.
    :param decoder: Decoder function able to map the constellation points back
        to the original messages. It is called on a tensor of 2-D points, and
        its output is reduced with ``argmax`` over the last dimension to
        select a message for each point.
    :param grid_step: Size of one decision-region cell, expressed as a
        fraction of the visible plot width (or equivalently height). Lower
        steps make more precise grids but take more time to compute.
    :param noise_samples: Number of noisy copies of each constellation point
        to sample and plot. Use 0 to skip plotting channel noise.
    :raises ValueError: If `grid_step` is not strictly positive.
    """
    if grid_step <= 0:
        raise ValueError('grid_step must be strictly positive')

    ax.grid()
    order = len(constellation)
    color_map = ListedColormap(seaborn.color_palette('husl', n_colors=order))
    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

    # Extend axes symmetrically around zero so that they fit the data. The
    # value is converted to a plain float so that matplotlib never receives
    # a tensor that may be attached to an autograd graph or live off the CPU.
    axis_extent = constellation.detach().abs().max().item() * 1.05

    if axis_extent == 0:
        # All points sit at the origin. Fall back to a unit extent instead of
        # producing zero-size limits and a zero grid step.
        axis_extent = 1.0

    ax.set_xlim(-axis_extent, axis_extent)
    ax.set_ylim(-axis_extent, axis_extent)

    # Hide borders but keep ticks
    for direction in ['left', 'bottom', 'right', 'top']:
        ax.axis[direction].line.set_color('#00000000')

    # Show zero-centered axes without tick labels
    for direction in ['xzero', 'yzero']:
        axis = ax.axis[direction]
        axis.set_visible(True)
        axis.set_axisline_style('-|>')
        axis.major_ticklabels.set_visible(False)

    # Add axis names
    ax.annotate(
        'I', (1, 0.5), xycoords='axes fraction',
        xytext=(25, 0), textcoords='offset points',
        va='center', ha='right'
    )
    ax.annotate(
        'Q', (0.5, 1), xycoords='axes fraction',
        xytext=(0, 25), textcoords='offset points',
        va='center', ha='center'
    )

    # Plot decision regions. The sampled area is twice as wide as the visible
    # one, so regions stay drawn when panning or zooming out.
    regions_extent = 2 * axis_extent
    step = grid_step * regions_extent

    # Sample the decoder at the *center* of each grid cell. imshow stretches
    # every pixel over a full cell of `extent`, so sampling at cell corners
    # would draw the decision boundaries shifted by half a cell.
    grid_range = torch.arange(
        -regions_extent + step / 2, regions_extent, step,
        device=constellation.device,
    )
    cell_count = len(grid_range)

    # Row index selects the Q (y) coordinate and column index the I (x)
    # coordinate, matching imshow's layout with origin='lower'.
    grid_x = grid_range.expand(cell_count, cell_count)
    grid_y = grid_range.unsqueeze(1).expand(cell_count, cell_count)
    grid_points = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)

    with torch.no_grad():
        # Plotting needs no gradients; skip building an autograd graph over
        # every grid point.
        grid_images = decoder(grid_points).argmax(dim=-1)
    grid_images = grid_images.reshape(cell_count, cell_count).cpu()

    # Right/top edge of the last sampled cell.
    grid_upper = -regions_extent + cell_count * step
    ax.imshow(
        grid_images,
        extent=(
            -regions_extent, grid_upper,
            -regions_extent, grid_upper
        ),
        aspect='auto',
        origin='lower',
        cmap=color_map,
        norm=color_norm,
        alpha=0.2
    )

    # Plot constellation, colored by message index
    ax.scatter(
        *zip(*constellation.tolist()),
        zorder=10,
        s=60,
        c=range(order),
        edgecolor='black',
        cmap=color_map,
        norm=color_norm,
    )

    # Plot constellation center (mean of all points)
    center = constellation.detach().mean(dim=0)
    ax.scatter(
        center[0].item(), center[1].item(),
        marker='X',
    )

    # Plot channel noise
    if noise_samples > 0:
        with torch.no_grad():
            # repeat() tiles the whole constellation noise_samples times, so
            # row k of the result comes from message k % order. This matches
            # the color list below.
            noisy_vectors = channel(constellation.repeat(noise_samples, 1))
        ax.scatter(
            *zip(*noisy_vectors.tolist()),
            marker='.',
            s=5,
            c=list(range(order)) * noise_samples,
            cmap=color_map,
            norm=color_norm,
            alpha=0.3,
            zorder=8
        )
def product_dict(**kwargs):
    """
    Lazily enumerate every combination of the given parameter values.

    Each keyword argument maps a parameter name to an iterable of candidate
    values. The generator yields one dict per element of the cartesian
    product of those iterables, keyed by parameter name. The last keyword
    varies fastest. With no keyword arguments, a single empty dict is
    yielded.

    >>> list(product_dict(first=[1, 2, 3], second=['a', 'b']))
    [{'first': 1, 'second': 'a'},
     {'first': 1, 'second': 'b'},
     {'first': 2, 'second': 'a'},
     {'first': 2, 'second': 'b'},
     {'first': 3, 'second': 'a'},
     {'first': 3, 'second': 'b'}]
    """
    names = list(kwargs)
    candidates = [kwargs[name] for name in names]
    for combination in itertools.product(*candidates):
        yield {name: value for name, value in zip(names, combination)}