Plot noise

This commit is contained in:
Mattéo Delabre 2019-12-15 01:07:05 -05:00
parent 6ea0e653c1
commit 9a87b322d4
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 23 additions and 6 deletions

29
plot.py
View File

@ -40,12 +40,12 @@ ax = SubplotZero(fig, 111)
fig.add_subplot(ax) fig.add_subplot(ax)
# Extend axes symmetrically around zero so that they fit data # Extend axes symmetrically around zero so that they fit data
extent = max( axis_extent = max(
abs(encoded_vectors.min()), abs(encoded_vectors.min()),
abs(encoded_vectors.max()) abs(encoded_vectors.max())
) * 1.05 ) * 1.05
ax.set_xlim(-extent, extent) ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-extent, extent) ax.set_ylim(-axis_extent, axis_extent)
# Hide borders # Hide borders
for direction in ['left', 'bottom', 'right', 'top']: for direction in ['left', 'bottom', 'right', 'top']:
@ -65,6 +65,7 @@ ax.annotate(
) )
ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white') ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
ax.axis['xzero'].set_zorder(9)
ax.axis['xzero'].major_ticklabels.set_ha('center') ax.axis['xzero'].major_ticklabels.set_ha('center')
ax.axis['xzero'].major_ticklabels.set_va('top') ax.axis['xzero'].major_ticklabels.set_va('top')
@ -76,6 +77,7 @@ ax.annotate(
ax.axis['yzero'].major_ticklabels.set_rotation(-90) ax.axis['yzero'].major_ticklabels.set_rotation(-90)
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white') ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
ax.axis['yzero'].set_zorder(9)
ax.axis['yzero'].major_ticklabels.set_ha('left') ax.axis['yzero'].major_ticklabels.set_ha('left')
ax.axis['yzero'].major_ticklabels.set_va('center') ax.axis['yzero'].major_ticklabels.set_va('center')
@ -94,13 +96,15 @@ ax.grid()
# Plot decision regions # Plot decision regions
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
step = 0.001 * extent regions_extent = 2 * axis_extent
grid_range = torch.arange(-extent, extent, step) step = 0.001 * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range) grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
ax.imshow( ax.imshow(
grid_images, extent=(-extent, extent, -extent, extent), grid_images,
extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
aspect="auto", aspect="auto",
origin="lower", origin="lower",
cmap=color_map, cmap=color_map,
@ -125,4 +129,17 @@ for row in range(order):
backgroundcolor='white', zorder=9 backgroundcolor='white', zorder=9
) )
# Plot noise
noisy_count = 1000
noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))
ax.scatter(
*zip(*noisy_vectors.tolist()),
s=1,
c=list(range(len(encoded_vectors))) * noisy_count,
cmap=color_map,
norm=color_norm,
zorder=8
)
pyplot.show() pyplot.show()