import torch
import matplotlib


def get_random_messages(count, order, *, generator=None):
    """
    Generate a batch of uniformly distributed random messages.

    :param count: Number of messages to generate.
    :param order: Number of possible messages.
    :param generator: Optional ``torch.Generator`` to draw from, e.g. for
        reproducible sampling. ``None`` uses torch's global RNG.
    :return: One-dimensional tensor of length ``count``; each entry is the
        index of a generated message, between 0 (inclusive) and ``order``
        (exclusive).

    >>> messages = get_random_messages(5, 4)
    >>> messages.shape
    torch.Size([5])
    >>> bool(((messages >= 0) & (messages < 4)).all())
    True
    """
    return torch.randint(0, order, (count,), generator=generator)


def messages_to_onehot(messages, order):
    """
    Convert messages represented as indices to one-hot encoding.

    :param messages: Message indices to convert, either an integer tensor or
        any sequence accepted by ``torch.as_tensor`` (e.g. a list of ints).
    :param order: Number of possible messages (width of each one-hot row).
    :return: Float tensor of shape ``messages.shape + (order,)``.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    messages = torch.as_tensor(messages)
    # one_hot only accepts int64 indices. Widen narrower integer types
    # (e.g. int32), but leave float/complex/bool untouched so that one_hot
    # still rejects them instead of silently truncating values.
    if not (
        messages.is_floating_point()
        or messages.is_complex()
        or messages.dtype == torch.bool
    ):
        messages = messages.long()
    return torch.nn.functional.one_hot(messages, num_classes=order).float()


def plot_constellation(ax, constellation):
    """
    Draw a constellation diagram on ``ax``.

    Each point is drawn as a coloured marker (colour keyed by its message
    index) and the mean of all points is marked with an ``X``. Axis limits
    are made symmetric around zero, the frame is hidden, and arrowed
    zero-centred axes labelled ``I`` and ``Q`` are shown instead.

    :param ax: Axes to draw on. Must be an axisartist-style axes whose
        ``axis`` mapping provides ``'xzero'`` and ``'yzero'`` (e.g. one
        created with ``mpl_toolkits.axisartist.axislines.AxesZero``).
    :param constellation: Tensor with one row per message; column 0 is the
        in-phase (I) and column 1 the quadrature (Q) coordinate. It is
        converted to Python floats before reaching matplotlib, so it may
        require grad or live on any device.
    """
    ax.grid()

    order = len(constellation)
    color_map = matplotlib.cm.Dark2
    # One colour bin per message index: value i falls into [i, i + 1) and
    # is therefore drawn with colour i of the colour map.
    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

    # .tolist() works for tensors that require grad or live on a GPU,
    # whereas matplotlib's implicit numpy conversion does not.
    xs, ys = zip(*constellation.tolist())

    # Extend axes symmetrically around zero so that they fit the data.
    axis_extent = max(abs(value) for value in xs + ys) * 1.05
    if axis_extent == 0:
        # All points at the origin: avoid singular (identical) limits.
        axis_extent = 1.0
    ax.set_xlim(-axis_extent, axis_extent)
    ax.set_ylim(-axis_extent, axis_extent)

    # Hide borders (fully transparent colour) but keep ticks.
    for direction in ['left', 'bottom', 'right', 'top']:
        ax.axis[direction].line.set_color('#00000000')

    # Show zero-centred arrowed axes without tick labels.
    for direction in ['xzero', 'yzero']:
        axis = ax.axis[direction]
        axis.set_visible(True)
        axis.set_axisline_style('-|>')
        axis.major_ticklabels.set_visible(False)

    # Add axis names just beyond the arrow tips.
    ax.annotate(
        'I', (1, 0.5), xycoords='axes fraction',
        xytext=(25, 0), textcoords='offset points',
        va='center', ha='right'
    )
    ax.annotate(
        'Q', (0.5, 1), xycoords='axes fraction',
        xytext=(0, 25), textcoords='offset points',
        va='center', ha='center'
    )

    ax.scatter(
        xs, ys,
        zorder=10, s=60,
        c=range(order),
        edgecolor='black',
        cmap=color_map,
        norm=color_norm,
    )

    # Plot the centre (mean) of the constellation; true division keeps this
    # valid for integer-valued constellations too.
    center_x, center_y = (constellation.sum(dim=0) / order).tolist()
    ax.scatter(
        center_x, center_y,
        marker='X',
    )