Improve plot legibility
This commit is contained in:
		
							parent
							
								
									8a31b22b83
								
							
						
					
					
						commit
						b97ba61f42
					
				
							
								
								
									
										49
									
								
								plot.py
								
								
								
								
							
							
						
						
									
										49
									
								
								plot.py
								
								
								
								
							|  | @ -12,7 +12,7 @@ order = 4 | |||
| input_file = 'output/constellation-order-{}.pth'.format(order) | ||||
| 
 | ||||
| # Color map used for decision regions | ||||
| color_map = matplotlib.cm.Set1 | ||||
| color_map = matplotlib.cm.Dark2 | ||||
| 
 | ||||
| # Restore model from file | ||||
| model = constellation.ConstellationNet( | ||||
|  | @ -47,50 +47,30 @@ axis_extent = max( | |||
| ax.set_xlim(-axis_extent, axis_extent) | ||||
| ax.set_ylim(-axis_extent, axis_extent) | ||||
| 
 | ||||
| # Hide borders | ||||
| # Hide borders but keep ticks | ||||
| for direction in ['left', 'bottom', 'right', 'top']: | ||||
|     ax.axis[direction].set_visible(False) | ||||
|     ax.axis[direction].line.set_color('#00000000') | ||||
| 
 | ||||
| # Show zero-centered axes | ||||
| # Show zero-centered axes without ticks | ||||
| for direction in ['xzero', 'yzero']: | ||||
|     axis = ax.axis[direction] | ||||
|     axis.set_visible(True) | ||||
|     axis.set_axisline_style("-|>") | ||||
|     axis.set_axisline_style('-|>') | ||||
|     axis.major_ticklabels.set_visible(False) | ||||
| 
 | ||||
| # Configure axes ticks and labels | ||||
| # Add axis names | ||||
| ax.annotate( | ||||
|     'I', (1, 0.5), xycoords='axes fraction', | ||||
|     xytext=(25, 0), textcoords='offset points', | ||||
|     va='center', ha='right' | ||||
| ) | ||||
| 
 | ||||
| ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white') | ||||
| ax.axis['xzero'].set_zorder(9) | ||||
| ax.axis['xzero'].major_ticklabels.set_ha('center') | ||||
| ax.axis['xzero'].major_ticklabels.set_va('top') | ||||
| 
 | ||||
| ax.annotate( | ||||
|     'Q', (0.5, 1), xycoords='axes fraction', | ||||
|     xytext=(0, 25), textcoords='offset points', | ||||
|     va='center', ha='center' | ||||
| ) | ||||
| 
 | ||||
| ax.axis['yzero'].major_ticklabels.set_rotation(-90) | ||||
| ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white') | ||||
| ax.axis['yzero'].set_zorder(9) | ||||
| ax.axis['yzero'].major_ticklabels.set_ha('left') | ||||
| ax.axis['yzero'].major_ticklabels.set_va('center') | ||||
| 
 | ||||
| # Add a single tick on 0 | ||||
| ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0]) | ||||
| ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0]) | ||||
| 
 | ||||
| ax.annotate( | ||||
|     '0', (0, 0), | ||||
|     xytext=(15, -10), textcoords='offset points', | ||||
|     va='center', ha='center' | ||||
| ) | ||||
| 
 | ||||
| ax.grid() | ||||
| 
 | ||||
| # Plot decision regions | ||||
|  | @ -105,8 +85,8 @@ grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2) | |||
| ax.imshow( | ||||
|     grid_images, | ||||
|     extent=(-regions_extent, regions_extent, -regions_extent, regions_extent), | ||||
|     aspect="auto", | ||||
|     origin="lower", | ||||
|     aspect='auto', | ||||
|     origin='lower', | ||||
|     cmap=color_map, | ||||
|     norm=color_norm, | ||||
|     alpha=0.1 | ||||
|  | @ -116,26 +96,21 @@ ax.imshow( | |||
| ax.scatter( | ||||
|     *zip(*encoded_vectors.tolist()), | ||||
|     zorder=10, | ||||
|     s=60, | ||||
|     c=range(len(encoded_vectors)), | ||||
|     edgecolor='black', | ||||
|     cmap=color_map, | ||||
|     norm=color_norm, | ||||
| ) | ||||
| 
 | ||||
| for row in range(order): | ||||
|     ax.annotate( | ||||
|         row + 1, encoded_vectors[row], | ||||
|         xytext=(5, 5), textcoords='offset points', | ||||
|         backgroundcolor='white', zorder=9 | ||||
|     ) | ||||
| 
 | ||||
| # Plot noise | ||||
| noisy_count = 1000 | ||||
| noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1)) | ||||
| 
 | ||||
| ax.scatter( | ||||
|     *zip(*noisy_vectors.tolist()), | ||||
|     s=1, | ||||
|     marker='.', | ||||
|     s=5, | ||||
|     c=list(range(len(encoded_vectors))) * noisy_count, | ||||
|     cmap=color_map, | ||||
|     norm=color_norm, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue