diff --git a/constellation/util.py b/constellation/util.py index bc8f649..61e8497 100644 --- a/constellation/util.py +++ b/constellation/util.py @@ -101,9 +101,12 @@ def plot_constellation( # Plot decision regions regions_extent = 2 * axis_extent step = grid_step * regions_extent + grid_range = torch.arange(-regions_extent, regions_extent, step) grid_y, grid_x = torch.meshgrid(grid_range, grid_range) - grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) + + grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1) + grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape) ax.imshow( grid_images,