Improve plot legibility

This commit is contained in:
Mattéo Delabre 2019-12-15 02:07:52 -05:00
parent 8a31b22b83
commit b97ba61f42
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
1 changed files with 12 additions and 37 deletions

49
plot.py
View File

@ -12,7 +12,7 @@ order = 4
input_file = 'output/constellation-order-{}.pth'.format(order) input_file = 'output/constellation-order-{}.pth'.format(order)
# Color map used for decision regions # Color map used for decision regions
color_map = matplotlib.cm.Set1 color_map = matplotlib.cm.Dark2
# Restore model from file # Restore model from file
model = constellation.ConstellationNet( model = constellation.ConstellationNet(
@ -47,50 +47,30 @@ axis_extent = max(
ax.set_xlim(-axis_extent, axis_extent) ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent) ax.set_ylim(-axis_extent, axis_extent)
# Hide borders # Hide borders but keep ticks
for direction in ['left', 'bottom', 'right', 'top']: for direction in ['left', 'bottom', 'right', 'top']:
ax.axis[direction].set_visible(False) ax.axis[direction].line.set_color('#00000000')
# Show zero-centered axes # Show zero-centered axes without ticks
for direction in ['xzero', 'yzero']: for direction in ['xzero', 'yzero']:
axis = ax.axis[direction] axis = ax.axis[direction]
axis.set_visible(True) axis.set_visible(True)
axis.set_axisline_style("-|>") axis.set_axisline_style('-|>')
axis.major_ticklabels.set_visible(False)
# Configure axes ticks and labels # Add axis names
ax.annotate( ax.annotate(
'I', (1, 0.5), xycoords='axes fraction', 'I', (1, 0.5), xycoords='axes fraction',
xytext=(25, 0), textcoords='offset points', xytext=(25, 0), textcoords='offset points',
va='center', ha='right' va='center', ha='right'
) )
ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
ax.axis['xzero'].set_zorder(9)
ax.axis['xzero'].major_ticklabels.set_ha('center')
ax.axis['xzero'].major_ticklabels.set_va('top')
ax.annotate( ax.annotate(
'Q', (0.5, 1), xycoords='axes fraction', 'Q', (0.5, 1), xycoords='axes fraction',
xytext=(0, 25), textcoords='offset points', xytext=(0, 25), textcoords='offset points',
va='center', ha='center' va='center', ha='center'
) )
ax.axis['yzero'].major_ticklabels.set_rotation(-90)
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
ax.axis['yzero'].set_zorder(9)
ax.axis['yzero'].major_ticklabels.set_ha('left')
ax.axis['yzero'].major_ticklabels.set_va('center')
# Add a single tick on 0
ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0])
ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0])
ax.annotate(
'0', (0, 0),
xytext=(15, -10), textcoords='offset points',
va='center', ha='center'
)
ax.grid() ax.grid()
# Plot decision regions # Plot decision regions
@ -105,8 +85,8 @@ grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
ax.imshow( ax.imshow(
grid_images, grid_images,
extent=(-regions_extent, regions_extent, -regions_extent, regions_extent), extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
aspect="auto", aspect='auto',
origin="lower", origin='lower',
cmap=color_map, cmap=color_map,
norm=color_norm, norm=color_norm,
alpha=0.1 alpha=0.1
@ -116,26 +96,21 @@ ax.imshow(
ax.scatter( ax.scatter(
*zip(*encoded_vectors.tolist()), *zip(*encoded_vectors.tolist()),
zorder=10, zorder=10,
s=60,
c=range(len(encoded_vectors)), c=range(len(encoded_vectors)),
edgecolor='black', edgecolor='black',
cmap=color_map, cmap=color_map,
norm=color_norm, norm=color_norm,
) )
for row in range(order):
ax.annotate(
row + 1, encoded_vectors[row],
xytext=(5, 5), textcoords='offset points',
backgroundcolor='white', zorder=9
)
# Plot noise # Plot noise
noisy_count = 1000 noisy_count = 1000
noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1)) noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))
ax.scatter( ax.scatter(
*zip(*noisy_vectors.tolist()), *zip(*noisy_vectors.tolist()),
s=1, marker='.',
s=5,
c=list(range(len(encoded_vectors))) * noisy_count, c=list(range(len(encoded_vectors))) * noisy_count,
cmap=color_map, cmap=color_map,
norm=color_norm, norm=color_norm,