Browse Source

Fix wrong input dimension for decision region plot

master
Mattéo Delabre 4 years ago
parent
commit
2b93a5f1bc
Signed by: matteo GPG Key ID: AE3FBD02DC583ABB
  1. 5
      constellation/util.py

5
constellation/util.py

@ -101,9 +101,12 @@ def plot_constellation(
# Plot decision regions
regions_extent = 2 * axis_extent
step = grid_step * regions_extent
grid_range = torch.arange(-regions_extent, regions_extent, step)
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1)
grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape)
ax.imshow(
grid_images,

Loading…
Cancel
Save