Fix wrong input dimension for decision region plot
This commit is contained in:
parent
e4457400a6
commit
2b93a5f1bc
|
@ -101,9 +101,12 @@ def plot_constellation(
|
|||
# Plot decision regions
|
||||
regions_extent = 2 * axis_extent
|
||||
step = grid_step * regions_extent
|
||||
|
||||
grid_range = torch.arange(-regions_extent, regions_extent, step)
|
||||
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
||||
grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
||||
|
||||
grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1)
|
||||
grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape)
|
||||
|
||||
ax.imshow(
|
||||
grid_images,
|
||||
|
|
Loading…
Reference in New Issue