176 lines
4.9 KiB
Python
176 lines
4.9 KiB
Python
import itertools
|
|
import torch
|
|
import matplotlib
|
|
from matplotlib.colors import ListedColormap
|
|
import seaborn
|
|
|
|
|
|
def get_random_messages(count, order, *, generator=None):
    """
    Generate a batch of uniformly random messages.

    :param count: Number of messages to generate.
    :param order: Number of possible messages.
    :param generator: Optional ``torch.Generator`` used as the source of
        randomness, e.g. for reproducible draws. When ``None`` (default),
        torch's global generator is used, as before.
    :return: One-dimensional integer tensor of length `count`, each entry
        being the index of a generated message, between 0 (inclusive) and
        `order` (exclusive).

    >>> get_random_messages(5, 5)  # doctest: +SKIP
    tensor([0, 2, 0, 3, 4])
    """
    return torch.randint(0, order, (count,), generator=generator)
|
|
|
|
|
|
def messages_to_onehot(messages, order):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Message indexes to convert, either as an integer tensor
        or as a (possibly nested) sequence of ints. Every index must lie in
        ``[0, order)``; torch raises otherwise.
    :param order: Number of possible messages.
    :return: Float tensor of one-hot vectors, with the shape of `messages`
        plus a trailing dimension of size `order`.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    # as_tensor returns tensors unchanged and turns plain Python lists into
    # int64 tensors, which one_hot would otherwise reject.
    indexes = torch.as_tensor(messages)
    return torch.nn.functional.one_hot(indexes, num_classes=order).float()
|
|
|
|
|
|
def plot_constellation(
    ax,
    constellation,
    channel,
    decoder,
    grid_step=0.05,
    noise_samples=1000
):
    """
    Plot a constellation with its decoder and channel noise.

    :param ax: Matplotlib axes to plot on. It must expose the ``ax.axis[...]``
        mapping with ``'xzero'`` / ``'yzero'`` entries, as provided by
        ``mpl_toolkits.axisartist`` axes (e.g. ``AxesZero``).
    :param constellation: Tensor of constellation points, one row per
        message, with the in-phase (I) coordinate in column 0 and the
        quadrature (Q) coordinate in column 1.
    :param channel: Channel model to use for generating noise. It is called
        on a tensor of stacked constellation points and must return noisy
        points in the same two-column layout.
    :param decoder: Decoder function able to map the constellation points back
        to the original messages. It is called on a batch of 2-D points, and
        the argmax over its last dimension is taken as the decoded message.
    :param grid_step: Grid step used for drawing the decision regions,
        expressed as a fraction of the total plot width (or equivalently
        height). Lower steps make more precise grids but take more time to
        compute.
    :param noise_samples: Number of noisy copies of *each* constellation
        point to sample and plot. Use 0 to skip the noise cloud.
    """
    ax.grid()

    # Work on a detached tensor: plotting needs no gradients, and matplotlib
    # cannot convert a tensor that requires grad into an array.
    constellation = constellation.detach()

    order = len(constellation)
    color_map = ListedColormap(seaborn.color_palette('husl', n_colors=order))
    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

    # Extend axes symmetrically around zero so that they fit data. `.item()`
    # yields plain floats for matplotlib and torch.arange.
    axis_extent = max(
        abs(constellation.min().item()),
        abs(constellation.max().item())
    ) * 1.05

    if axis_extent == 0:
        # All points at the origin: use a unit window so that the
        # decision-region grid step below is not zero.
        axis_extent = 1.0

    ax.set_xlim(-axis_extent, axis_extent)
    ax.set_ylim(-axis_extent, axis_extent)

    _draw_iq_axes(ax)

    _plot_decision_regions(
        ax, decoder, axis_extent, grid_step, color_map, color_norm,
        dtype=constellation.dtype, device=constellation.device,
    )

    # Plot constellation: one marker per message, colored by message index
    ax.scatter(
        *zip(*constellation.tolist()),
        zorder=10,
        s=60,
        c=range(order),
        edgecolor='black',
        cmap=color_map,
        norm=color_norm,
    )

    # Plot constellation center (mean of all points)
    center = (constellation.sum(dim=0) / order).tolist()
    ax.scatter(
        center[0], center[1],
        marker='X',
    )

    # Plot channel noise. `repeat` tiles the whole constellation
    # `noise_samples` times in message order, so the color list built below
    # lines up with the noisy points.
    if noise_samples > 0:
        with torch.no_grad():
            noisy_vectors = channel(constellation.repeat(noise_samples, 1))

        ax.scatter(
            *zip(*noisy_vectors.tolist()),
            marker='.',
            s=5,
            c=list(range(order)) * noise_samples,
            cmap=color_map,
            norm=color_norm,
            alpha=0.3,
            zorder=8
        )


def _draw_iq_axes(ax):
    """Replace the axes frame with arrowed zero lines labelled I and Q."""
    # The 'xzero' / 'yzero' entries of `ax.axis` exist on axisartist axes
    # (e.g. mpl_toolkits.axisartist.AxesZero), not on plain matplotlib Axes.

    # Hide borders (fully transparent color) but keep ticks
    for direction in ['left', 'bottom', 'right', 'top']:
        ax.axis[direction].line.set_color('#00000000')

    # Show zero-centered axes with arrowheads and without tick labels
    for direction in ['xzero', 'yzero']:
        axis = ax.axis[direction]
        axis.set_visible(True)
        axis.set_axisline_style('-|>')
        axis.major_ticklabels.set_visible(False)

    # Add axis names just outside the right (I) and top (Q) edges
    ax.annotate(
        'I', (1, 0.5), xycoords='axes fraction',
        xytext=(25, 0), textcoords='offset points',
        va='center', ha='right'
    )

    ax.annotate(
        'Q', (0.5, 1), xycoords='axes fraction',
        xytext=(0, 25), textcoords='offset points',
        va='center', ha='center'
    )


def _plot_decision_regions(
    ax, decoder, axis_extent, grid_step, color_map, color_norm,
    *, dtype, device
):
    """Shade a grid with the color of the message the decoder picks there."""
    # Sample a square twice as wide as the visible window [-axis_extent,
    # axis_extent]; the step is `grid_step` times the visible plot width.
    regions_extent = 2 * axis_extent
    step = grid_step * regions_extent

    # Build the grid with the constellation's dtype/device so the decoder
    # receives inputs in the same format as the points it was built for.
    grid_range = torch.arange(
        -regions_extent, regions_extent, step, dtype=dtype, device=device
    )
    # 'ij' indexing: rows follow y and columns follow x, which is the layout
    # imshow expects with origin='lower'.
    grid_y, grid_x = torch.meshgrid(grid_range, grid_range, indexing='ij')

    grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1)
    with torch.no_grad():
        grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape)

    # imshow's extent gives the outer *edges* of the pixels, whereas the
    # decoder was evaluated at the grid samples. Pad by half a step on each
    # side so every pixel is centered on the point it was computed from.
    half_step = step / 2
    low = grid_range[0].item() - half_step
    high = grid_range[-1].item() + half_step

    ax.imshow(
        grid_images.cpu(),
        extent=(low, high, low, high),
        aspect='auto',
        origin='lower',
        cmap=color_map,
        norm=color_norm,
        alpha=0.2
    )
|
|
|
|
|
|
def product_dict(**kwargs):
    """
    Iterate over every combination of the given parameter values.

    Each keyword argument maps a parameter name to an iterable of candidate
    values. One dictionary is yielded per element of their cartesian
    product, with the last keyword varying fastest.

    >>> list(product_dict(first=[1, 2, 3], second=['a', 'b']))  # doctest: +NORMALIZE_WHITESPACE
    [{'first': 1, 'second': 'a'},
     {'first': 1, 'second': 'b'},
     {'first': 2, 'second': 'a'},
     {'first': 2, 'second': 'b'},
     {'first': 3, 'second': 'a'},
     {'first': 3, 'second': 'b'}]
    """
    names = tuple(kwargs)
    combinations = itertools.product(*kwargs.values())
    yield from (dict(zip(names, combo)) for combo in combinations)
|