Improve plot legibility
This commit is contained in:
parent
8a31b22b83
commit
b97ba61f42
49
plot.py
49
plot.py
|
@ -12,7 +12,7 @@ order = 4
|
|||
input_file = 'output/constellation-order-{}.pth'.format(order)
|
||||
|
||||
# Color map used for decision regions
|
||||
color_map = matplotlib.cm.Set1
|
||||
color_map = matplotlib.cm.Dark2
|
||||
|
||||
# Restore model from file
|
||||
model = constellation.ConstellationNet(
|
||||
|
@ -47,50 +47,30 @@ axis_extent = max(
|
|||
ax.set_xlim(-axis_extent, axis_extent)
|
||||
ax.set_ylim(-axis_extent, axis_extent)
|
||||
|
||||
# Hide borders
|
||||
# Hide borders but keep ticks
|
||||
for direction in ['left', 'bottom', 'right', 'top']:
|
||||
ax.axis[direction].set_visible(False)
|
||||
ax.axis[direction].line.set_color('#00000000')
|
||||
|
||||
# Show zero-centered axes
|
||||
# Show zero-centered axes without ticks
|
||||
for direction in ['xzero', 'yzero']:
|
||||
axis = ax.axis[direction]
|
||||
axis.set_visible(True)
|
||||
axis.set_axisline_style("-|>")
|
||||
axis.set_axisline_style('-|>')
|
||||
axis.major_ticklabels.set_visible(False)
|
||||
|
||||
# Configure axes ticks and labels
|
||||
# Add axis names
|
||||
ax.annotate(
|
||||
'I', (1, 0.5), xycoords='axes fraction',
|
||||
xytext=(25, 0), textcoords='offset points',
|
||||
va='center', ha='right'
|
||||
)
|
||||
|
||||
ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
|
||||
ax.axis['xzero'].set_zorder(9)
|
||||
ax.axis['xzero'].major_ticklabels.set_ha('center')
|
||||
ax.axis['xzero'].major_ticklabels.set_va('top')
|
||||
|
||||
ax.annotate(
|
||||
'Q', (0.5, 1), xycoords='axes fraction',
|
||||
xytext=(0, 25), textcoords='offset points',
|
||||
va='center', ha='center'
|
||||
)
|
||||
|
||||
ax.axis['yzero'].major_ticklabels.set_rotation(-90)
|
||||
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
|
||||
ax.axis['yzero'].set_zorder(9)
|
||||
ax.axis['yzero'].major_ticklabels.set_ha('left')
|
||||
ax.axis['yzero'].major_ticklabels.set_va('center')
|
||||
|
||||
# Add a single tick on 0
|
||||
ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0])
|
||||
ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0])
|
||||
|
||||
ax.annotate(
|
||||
'0', (0, 0),
|
||||
xytext=(15, -10), textcoords='offset points',
|
||||
va='center', ha='center'
|
||||
)
|
||||
|
||||
ax.grid()
|
||||
|
||||
# Plot decision regions
|
||||
|
@ -105,8 +85,8 @@ grid_images = model.decoder(torch.stack((grid_x, grid_y), dim=2)).argmax(dim=2)
|
|||
ax.imshow(
|
||||
grid_images,
|
||||
extent=(-regions_extent, regions_extent, -regions_extent, regions_extent),
|
||||
aspect="auto",
|
||||
origin="lower",
|
||||
aspect='auto',
|
||||
origin='lower',
|
||||
cmap=color_map,
|
||||
norm=color_norm,
|
||||
alpha=0.1
|
||||
|
@ -116,26 +96,21 @@ ax.imshow(
|
|||
ax.scatter(
|
||||
*zip(*encoded_vectors.tolist()),
|
||||
zorder=10,
|
||||
s=60,
|
||||
c=range(len(encoded_vectors)),
|
||||
edgecolor='black',
|
||||
cmap=color_map,
|
||||
norm=color_norm,
|
||||
)
|
||||
|
||||
for row in range(order):
|
||||
ax.annotate(
|
||||
row + 1, encoded_vectors[row],
|
||||
xytext=(5, 5), textcoords='offset points',
|
||||
backgroundcolor='white', zorder=9
|
||||
)
|
||||
|
||||
# Plot noise
|
||||
noisy_count = 1000
|
||||
noisy_vectors = model.channel(encoded_vectors.repeat(noisy_count, 1))
|
||||
|
||||
ax.scatter(
|
||||
*zip(*noisy_vectors.tolist()),
|
||||
s=1,
|
||||
marker='.',
|
||||
s=5,
|
||||
c=list(range(len(encoded_vectors))) * noisy_count,
|
||||
cmap=color_map,
|
||||
norm=color_norm,
|
||||
|
|
Loading…
Reference in New Issue