From 2b93a5f1bc418802c104b0bf29eaf025a39f0102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Sun, 15 Dec 2019 23:53:14 -0500 Subject: [PATCH] Fix wrong input dimension for decision region plot --- constellation/util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/constellation/util.py b/constellation/util.py index bc8f649..61e8497 100644 --- a/constellation/util.py +++ b/constellation/util.py @@ -101,9 +101,12 @@ def plot_constellation( # Plot decision regions regions_extent = 2 * axis_extent step = grid_step * regions_extent + grid_range = torch.arange(-regions_extent, regions_extent, step) grid_y, grid_x = torch.meshgrid(grid_range, grid_range) - grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) + + grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1) + grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape) ax.imshow( grid_images,