diff --git a/plot.py b/plot.py index c95fed2..a5f1b25 100644 --- a/plot.py +++ b/plot.py @@ -40,12 +40,12 @@ ax = SubplotZero(fig, 111) fig.add_subplot(ax) # Extend axes symmetrically around zero so that they fit data -extent = max( +axis_extent = max( abs(encoded_vectors.min()), abs(encoded_vectors.max()) ) * 1.05 -ax.set_xlim(-extent, extent) -ax.set_ylim(-extent, extent) +ax.set_xlim(-axis_extent, axis_extent) +ax.set_ylim(-axis_extent, axis_extent) # Hide borders for direction in ['left', 'bottom', 'right', 'top']: @@ -65,6 +65,7 @@ ax.annotate( ) ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white') +ax.axis['xzero'].set_zorder(9) ax.axis['xzero'].major_ticklabels.set_ha('center') ax.axis['xzero'].major_ticklabels.set_va('top') @@ -76,6 +77,7 @@ ax.annotate( ax.axis['yzero'].major_ticklabels.set_rotation(-90) ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white') +ax.axis['yzero'].set_zorder(9) ax.axis['yzero'].major_ticklabels.set_ha('left') ax.axis['yzero'].major_ticklabels.set_va('center') @@ -94,13 +96,15 @@ ax.grid() # Plot decision regions color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) -step = 0.001 * extent -grid_range = torch.arange(-extent, extent, step) +regions_extent = 2 * axis_extent +step = 0.001 * regions_extent +grid_range = torch.arange(-regions_extent, regions_extent, step) grid_y, grid_x = torch.meshgrid(grid_range, grid_range) grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) ax.imshow( - grid_images, extent=(-extent, extent, -extent, extent), + grid_images, + extent=(-regions_extent, regions_extent, -regions_extent, regions_extent), aspect="auto", origin="lower", cmap=color_map, @@ -125,4 +129,17 @@ for row in range(order): backgroundcolor='white', zorder=9 ) +# Plot noise +noisy_count = 1000 +noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1)) + +ax.scatter( + *zip(*noisy_vectors.tolist()), + s=1, + c=list(range(len(encoded_vectors))) * noisy_count, + cmap=color_map, + norm=color_norm, + zorder=8 +) + pyplot.show()