diff --git a/plot.py b/plot.py index 8a12445..08ad24b 100644 --- a/plot.py +++ b/plot.py @@ -12,7 +12,7 @@ order = 4 input_file = 'output/constellation-order-{}.pth'.format(order) # Color map used for decision regions -color_map = matplotlib.cm.Set1 +color_map = matplotlib.cm.Dark2 # Restore model from file model = constellation.ConstellationNet( @@ -47,50 +47,30 @@ axis_extent = max( ax.set_xlim(-axis_extent, axis_extent) ax.set_ylim(-axis_extent, axis_extent) -# Hide borders +# Hide borders but keep ticks for direction in ['left', 'bottom', 'right', 'top']: - ax.axis[direction].set_visible(False) + ax.axis[direction].line.set_color('#00000000') -# Show zero-centered axes +# Show zero-centered axes without ticks for direction in ['xzero', 'yzero']: axis = ax.axis[direction] axis.set_visible(True) - axis.set_axisline_style("-|>") + axis.set_axisline_style('-|>') + axis.major_ticklabels.set_visible(False) -# Configure axes ticks and labels +# Add axis names ax.annotate( 'I', (1, 0.5), xycoords='axes fraction', xytext=(25, 0), textcoords='offset points', va='center', ha='right' ) -ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white') -ax.axis['xzero'].set_zorder(9) -ax.axis['xzero'].major_ticklabels.set_ha('center') -ax.axis['xzero'].major_ticklabels.set_va('top') - ax.annotate( 'Q', (0.5, 1), xycoords='axes fraction', xytext=(0, 25), textcoords='offset points', va='center', ha='center' ) -ax.axis['yzero'].major_ticklabels.set_rotation(-90) -ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white') -ax.axis['yzero'].set_zorder(9) -ax.axis['yzero'].major_ticklabels.set_ha('left') -ax.axis['yzero'].major_ticklabels.set_va('center') - -# Add a single tick on 0 -ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0]) -ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0]) - -ax.annotate( - '0', (0, 0), - xytext=(15, -10), textcoords='offset points', - va='center', ha='center' -) - ax.grid() # Plot decision regions @@ -105,8 +85,8 @@ grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) ax.imshow( grid_images, extent=(-regions_extent, regions_extent, -regions_extent, regions_extent), - aspect="auto", - origin="lower", + aspect='auto', + origin='lower', cmap=color_map, norm=color_norm, alpha=0.1 @@ -116,26 +96,21 @@ ax.imshow( ax.scatter( *zip(*encoded_vectors.tolist()), zorder=10, + s=60, c=range(len(encoded_vectors)), edgecolor='black', cmap=color_map, norm=color_norm, ) -for row in range(order): - ax.annotate( - row + 1, encoded_vectors[row], - xytext=(5, 5), textcoords='offset points', - backgroundcolor='white', zorder=9 - ) - # Plot noise noisy_count = 1000 noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1)) ax.scatter( *zip(*noisy_vectors.tolist()), - s=1, + marker='.', + s=5, c=list(range(len(encoded_vectors))) * noisy_count, cmap=color_map, norm=color_norm,