constellationnet/constellation/util.py

95 lines
2.5 KiB
Python

import torch
import matplotlib
def get_random_messages(count, order):
    """
    Generate a vector of random message indexes.

    :param count: Number of messages to generate.
    :param order: Number of possible messages.
    :return: One-dimensional int64 tensor of length `count`; each entry is the
        index of a generated message, between 0 (inclusive) and `order`
        (exclusive).

    >>> messages = get_random_messages(5, 4)
    >>> messages.shape
    torch.Size([5])
    >>> bool(((messages >= 0) & (messages < 4)).all())
    True
    """
    return torch.randint(0, order, (count,))
def messages_to_onehot(messages, order):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Message indexes to convert, as an integer tensor or a
        plain sequence of ints (e.g. a list). Every index must be between
        0 (inclusive) and `order` (exclusive).
    :param order: Number of possible messages, i.e. the length of each
        one-hot vector.
    :return: Float tensor of shape `messages.shape + (order,)` holding the
        one-hot encoded messages.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    # `as_tensor` returns tensor inputs unchanged and turns a list of ints
    # into an int64 tensor, which is what `one_hot` requires.
    indexes = torch.as_tensor(messages)
    return torch.nn.functional.one_hot(indexes, num_classes=order).float()
def plot_constellation(ax, constellation):
    """
    Draw a constellation diagram onto an existing axes.

    Each point is drawn as a colored dot, with the color chosen by the point's
    index in `constellation`. The centroid of all points is marked with an "X".
    The view is made symmetric around the origin, the outer borders are hidden,
    and zero-centered arrow axes labeled "I" and "Q" are shown.

    :param ax: Axes to draw on. It must support `ax.axis[name]` lookups for
        'left', 'bottom', 'right', 'top', 'xzero' and 'yzero'. It is presumably
        an `mpl_toolkits.axisartist` axes such as `AxesZero`; confirm with
        callers.
    :param constellation: Tensor with one row per point, each row holding the
        (I, Q) coordinates of that point. A detached CPU copy is plotted, so
        tensors that require grad or live on another device are accepted.
    :raises ValueError: If `constellation` contains no points.
    """
    # Matplotlib converts its inputs through NumPy, which fails for tensors
    # that require grad or are not on the CPU; plot a plain CPU copy instead.
    constellation = constellation.detach().cpu()
    order = len(constellation)
    if order == 0:
        raise ValueError('constellation must contain at least one point')

    ax.grid()

    # One discrete color per point: BoundaryNorm with boundaries 0..order maps
    # point index i (which lies in [i, i + 1)) to color bin i.
    # NOTE(review): Dark2 is a small qualitative colormap, so large
    # constellations may end up with repeated colors — confirm this is fine.
    color_map = matplotlib.cm.Dark2
    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

    # Extend axes symmetrically around zero so that they fit the data (plus a
    # 5% margin). `.item()` hands matplotlib a plain Python number.
    axis_extent = constellation.abs().max().item() * 1.05
    if axis_extent == 0:
        # Every point sits at the origin; avoid a zero-width axis range.
        axis_extent = 1.0
    ax.set_xlim(-axis_extent, axis_extent)
    ax.set_ylim(-axis_extent, axis_extent)

    # Hide borders but keep ticks (fully transparent line color)
    for direction in ['left', 'bottom', 'right', 'top']:
        ax.axis[direction].line.set_color('#00000000')

    # Show zero-centered axes with arrow heads and without tick labels
    for direction in ['xzero', 'yzero']:
        axis = ax.axis[direction]
        axis.set_visible(True)
        axis.set_axisline_style('-|>')
        axis.major_ticklabels.set_visible(False)

    # Add axis names: "I" just right of the middle of the right edge and
    # "Q" just above the middle of the top edge.
    ax.annotate(
        'I', (1, 0.5), xycoords='axes fraction',
        xytext=(25, 0), textcoords='offset points',
        va='center', ha='right'
    )
    ax.annotate(
        'Q', (0.5, 1), xycoords='axes fraction',
        xytext=(0, 25), textcoords='offset points',
        va='center', ha='center'
    )

    # Constellation points; zorder keeps them above the grid and axes.
    ax.scatter(
        *zip(*constellation.tolist()),
        zorder=10,
        s=60,
        c=range(len(constellation)),
        edgecolor='black',
        cmap=color_map,
        norm=color_norm,
    )

    # Plot center (mean of all points)
    center = constellation.sum(dim=0) / order
    ax.scatter(
        center[0].item(), center[1].item(),
        marker='X',
    )