|
|
@ -101,9 +101,12 @@ def plot_constellation( |
|
|
|
# Plot decision regions |
|
|
|
regions_extent = 2 * axis_extent |
|
|
|
step = grid_step * regions_extent |
|
|
|
|
|
|
|
grid_range = torch.arange(-regions_extent, regions_extent, step) |
|
|
|
grid_y, grid_x = torch.meshgrid(grid_range, grid_range) |
|
|
|
grid_images = decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) |
|
|
|
|
|
|
|
grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1) |
|
|
|
grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape) |
|
|
|
|
|
|
|
ax.imshow( |
|
|
|
grid_images, |
|
|
|