Plot noise
This commit is contained in:
parent
6ea0e653c1
commit
9a87b322d4
29
plot.py
29
plot.py
|
@ -40,12 +40,12 @@ ax = SubplotZero(fig, 111)
|
||||||
fig.add_subplot(ax)
|
fig.add_subplot(ax)
|
||||||
|
|
||||||
# Extend axes symmetrically around zero so that they fit data
|
# Extend axes symmetrically around zero so that they fit data
|
||||||
extent = max(
|
axis_extent = max(
|
||||||
abs(encoded_vectors.min()),
|
abs(encoded_vectors.min()),
|
||||||
abs(encoded_vectors.max())
|
abs(encoded_vectors.max())
|
||||||
) * 1.05
|
) * 1.05
|
||||||
ax.set_xlim(-extent, extent)
|
ax.set_xlim(-axis_extent, axis_extent)
|
||||||
ax.set_ylim(-extent, extent)
|
ax.set_ylim(-axis_extent, axis_extent)
|
||||||
|
|
||||||
# Hide borders
|
# Hide borders
|
||||||
for direction in ['left', 'bottom', 'right', 'top']:
|
for direction in ['left', 'bottom', 'right', 'top']:
|
||||||
|
@ -65,6 +65,7 @@ ax.annotate(
|
||||||
)
|
)
|
||||||
|
|
||||||
ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
|
ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
|
||||||
|
ax.axis['xzero'].set_zorder(9)
|
||||||
ax.axis['xzero'].major_ticklabels.set_ha('center')
|
ax.axis['xzero'].major_ticklabels.set_ha('center')
|
||||||
ax.axis['xzero'].major_ticklabels.set_va('top')
|
ax.axis['xzero'].major_ticklabels.set_va('top')
|
||||||
|
|
||||||
|
@ -76,6 +77,7 @@ ax.annotate(
|
||||||
|
|
||||||
ax.axis['yzero'].major_ticklabels.set_rotation(-90)
|
ax.axis['yzero'].major_ticklabels.set_rotation(-90)
|
||||||
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
|
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
|
||||||
|
ax.axis['yzero'].set_zorder(9)
|
||||||
ax.axis['yzero'].major_ticklabels.set_ha('left')
|
ax.axis['yzero'].major_ticklabels.set_ha('left')
|
||||||
ax.axis['yzero'].major_ticklabels.set_va('center')
|
ax.axis['yzero'].major_ticklabels.set_va('center')
|
||||||
|
|
||||||
|
@ -94,13 +96,15 @@ ax.grid()
|
||||||
# Plot decision regions
|
# Plot decision regions
|
||||||
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
|
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
|
||||||
|
|
||||||
step = 0.001 * extent
|
regions_extent = 2 * axis_extent
|
||||||
grid_range = torch.arange(-extent, extent, step)
|
step = 0.001 * regions_extent
|
||||||
|
grid_range = torch.arange(-regions_extent, regions_extent, step)
|
||||||
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
||||||
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
||||||
|
|
||||||
ax.imshow(
|
ax.imshow(
|
||||||
grid_images, extent=(-extent, extent, -extent, extent),
|
grid_images,
|
||||||
|
extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
|
||||||
aspect="auto",
|
aspect="auto",
|
||||||
origin="lower",
|
origin="lower",
|
||||||
cmap=color_map,
|
cmap=color_map,
|
||||||
|
@ -125,4 +129,17 @@ for row in range(order):
|
||||||
backgroundcolor='white', zorder=9
|
backgroundcolor='white', zorder=9
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Plot noise
|
||||||
|
noisy_count = 1000
|
||||||
|
noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))
|
||||||
|
|
||||||
|
ax.scatter(
|
||||||
|
*zip(*noisy_vectors.tolist()),
|
||||||
|
s=1,
|
||||||
|
c=list(range(len(encoded_vectors))) * noisy_count,
|
||||||
|
cmap=color_map,
|
||||||
|
norm=color_norm,
|
||||||
|
zorder=8
|
||||||
|
)
|
||||||
|
|
||||||
pyplot.show()
|
pyplot.show()
|
||||||
|
|
Loading…
Reference in New Issue