Fix wrong input dimension for decision region plot

This commit is contained in:
Mattéo Delabre 2019-12-15 23:53:14 -05:00
parent e4457400a6
commit 2b93a5f1bc
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 4 additions and 1 deletions

View File

@ -101,9 +101,12 @@ def plot_constellation(
# Plot decision regions
regions_extent = 2 * axis_extent
step = grid_step * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1)
grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape)
ax.imshow(
grid_images,