Fix wrong input dimension for decision region plot
This commit is contained in:
parent
e4457400a6
commit
2b93a5f1bc
|
@ -101,9 +101,12 @@ def plot_constellation(
|
||||||
# Plot decision regions
|
# Plot decision regions
|
||||||
regions_extent = 2 * axis_extent
|
regions_extent = 2 * axis_extent
|
||||||
step = grid_step * regions_extent
|
step = grid_step * regions_extent
|
||||||
|
|
||||||
grid_range = torch.arange(-regions_extent, regions_extent, step)
|
grid_range = torch.arange(-regions_extent, regions_extent, step)
|
||||||
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
|
||||||
grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
|
||||||
|
grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1)
|
||||||
|
grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape)
|
||||||
|
|
||||||
ax.imshow(
|
ax.imshow(
|
||||||
grid_images,
|
grid_images,
|
||||||
|
|
Loading…
Reference in New Issue