import itertools

import matplotlib
import seaborn
import torch
from matplotlib.colors import ListedColormap


def get_random_messages(count, order):
    """
    Generate a batch of uniformly random messages.

    :param count: Number of messages to generate.
    :param order: Number of possible messages.
    :return: One-dimensional integer tensor of length `count` whose entries
        are message indexes between 0, inclusive, and `order`, exclusive.

    >>> messages = get_random_messages(5, 4)
    >>> messages.shape
    torch.Size([5])
    >>> bool(((messages >= 0) & (messages < 4)).all())
    True
    """
    return torch.randint(0, order, (count,))


def messages_to_onehot(messages, order):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Integer (int64) tensor of message indexes, each between
        0, inclusive, and `order`, exclusive.
    :param order: Number of possible messages.
    :return: Float tensor with one extra trailing dimension of size `order`
        holding the one-hot encoding of each message.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    return torch.nn.functional.one_hot(messages, num_classes=order).float()


def plot_constellation(
    ax, constellation, channel, decoder, grid_step=0.05, noise_samples=1000
):
    """
    Plot a constellation together with its decision regions and channel noise.

    :param ax: Matplotlib axes to plot on. It must be an
        ``mpl_toolkits.axisartist`` axes that provides the ``xzero`` and
        ``yzero`` axis lines (for example ``AxesZero``).
    :param constellation: Tensor of shape ``(order, 2)`` whose row ``k`` holds
        the (I, Q) coordinates of the point assigned to message ``k``.
    :param channel: Channel model. It is called with a tensor of points of
        shape ``(n, 2)`` and must return noisy points of the same shape and
        row order.
    :param decoder: Decoder mapping a tensor of points of shape ``(n, 2)`` to
        per-message scores of shape ``(n, order)``. The highest-scoring
        message is used as the decision for each point.
    :param grid_step: Size of one decision-region cell, expressed as a fraction
        of the total plot width (or equivalently height); e.g. 0.05 means 5%.
        Lower steps make more precise regions but take more time to compute.
    :param noise_samples: Number of noisy copies of the constellation to
        sample and plot. Use 0 to skip plotting noise.
    :raises ValueError: If `grid_step` is not positive.
    """
    if grid_step <= 0:
        raise ValueError('grid_step must be positive')

    # Plotting never needs gradients. Detaching also lets matplotlib convert
    # the values: it cannot call `.numpy()` on a tensor that requires grad.
    constellation = constellation.detach()

    ax.grid()
    order = len(constellation)
    color_map = ListedColormap(seaborn.color_palette('husl', n_colors=order))
    # One color bin per message index: [0, 1), [1, 2), ..., [order - 1, order)
    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)

    # Extend axes symmetrically around zero so that they fit data. Work with
    # a plain float so axis limits and grid bounds are not tensors.
    axis_extent = constellation.abs().max().item() * 1.05
    if axis_extent == 0:
        # Every point sits at the origin; fall back to a unit view instead of
        # a zero-size one.
        axis_extent = 1.0
    ax.set_xlim(-axis_extent, axis_extent)
    ax.set_ylim(-axis_extent, axis_extent)

    # Hide borders but keep ticks
    for direction in ['left', 'bottom', 'right', 'top']:
        ax.axis[direction].line.set_color('#00000000')

    # Show zero-centered axes without ticks
    for direction in ['xzero', 'yzero']:
        axis = ax.axis[direction]
        axis.set_visible(True)
        axis.set_axisline_style('-|>')
        axis.major_ticklabels.set_visible(False)

    # Add axis names
    ax.annotate(
        'I', (1, 0.5), xycoords='axes fraction',
        xytext=(25, 0), textcoords='offset points',
        va='center', ha='right'
    )
    ax.annotate(
        'Q', (0.5, 1), xycoords='axes fraction',
        xytext=(0, 25), textcoords='offset points',
        va='center', ha='center'
    )

    # Plot decision regions over twice the visible extent. The square
    # [-R, R]^2 is split into `cell_count` cells per side, so each cell is
    # about `grid_step` times the plot width (2 * axis_extent). The decoder is
    # evaluated at each cell's *center*, which is exactly where `imshow`
    # places the corresponding pixel for the same `extent`. Sampling at the
    # cell corner would shift the drawn regions by half a cell.
    regions_extent = 2 * axis_extent
    cell_count = max(1, round(2 / grid_step))
    half_cell = regions_extent / cell_count
    grid_range = torch.linspace(
        -regions_extent + half_cell, regions_extent - half_cell, cell_count
    )
    # Rows follow the Q axis and columns the I axis, matching origin='lower'
    grid_x = grid_range.expand(cell_count, cell_count)
    grid_y = grid_range.unsqueeze(1).expand(cell_count, cell_count)
    grid_points = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)

    with torch.no_grad():
        decisions = decoder(grid_points).argmax(dim=-1)
    grid_images = decisions.reshape(cell_count, cell_count).cpu()

    ax.imshow(
        grid_images,
        extent=(
            -regions_extent, regions_extent,
            -regions_extent, regions_extent
        ),
        aspect='auto',
        origin='lower',
        cmap=color_map,
        norm=color_norm,
        alpha=0.2
    )

    # Plot constellation
    ax.scatter(
        *zip(*constellation.tolist()),
        zorder=10,
        s=60,
        c=range(order),
        edgecolor='black',
        cmap=color_map,
        norm=color_norm,
    )

    # Plot constellation center (converted to floats for matplotlib)
    center = (constellation.sum(dim=0) / order).tolist()
    ax.scatter(
        center[0], center[1],
        marker='X',
    )

    # Plot channel noise
    if noise_samples > 0:
        with torch.no_grad():
            noisy_vectors = channel(constellation.repeat(noise_samples, 1))
        ax.scatter(
            *zip(*noisy_vectors.tolist()),
            marker='.',
            s=5,
            # `repeat` tiles the whole constellation, so row k of every copy
            # is a noisy version of message k's point
            c=list(range(order)) * noise_samples,
            cmap=color_map,
            norm=color_norm,
            alpha=0.3,
            zorder=8
        )


def product_dict(**kwargs):
    """
    Compute the cartesian product of a set of named parameter lists.

    Yields one dictionary per combination, mapping every keyword to one of
    its values. As with `itertools.product`, the last keyword varies fastest.

    >>> list(product_dict(first=[1, 2, 3], second=['a', 'b']))
    ... # doctest: +NORMALIZE_WHITESPACE
    [{'first': 1, 'second': 'a'}, {'first': 1, 'second': 'b'},
     {'first': 2, 'second': 'a'}, {'first': 2, 'second': 'b'},
     {'first': 3, 'second': 'a'}, {'first': 3, 'second': 'b'}]
    """
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))