Improve plot
This commit is contained in:
parent
49e63775dd
commit
a01d83f339
77
plot.py
77
plot.py
|
@ -2,12 +2,15 @@ import constellation
|
|||
from constellation import util
|
||||
import torch
|
||||
from matplotlib import pyplot
|
||||
from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||
import math
|
||||
import numpy
|
||||
|
||||
# Number learned symbols
|
||||
order = 4
|
||||
|
||||
# File in which the trained model is saved
|
||||
input_file = 'output/constellation-net.tc'
|
||||
input_file = 'output/constellation-order-{}.tc'.format(order)
|
||||
|
||||
model = constellation.ConstellationNet(order=order)
|
||||
model.load_state_dict(torch.load(input_file))
|
||||
|
@ -19,8 +22,74 @@ with torch.no_grad():
|
|||
torch.arange(0, order),
|
||||
order
|
||||
)
|
||||
).tolist()
|
||||
)
|
||||
|
||||
fig = pyplot.figure()
|
||||
ax = SubplotZero(fig, 111)
|
||||
fig.add_subplot(ax)
|
||||
|
||||
# Extend axes symmetrically around zero so that they fit data
|
||||
extent = max(
|
||||
abs(encoded_vectors.min()),
|
||||
abs(encoded_vectors.max())
|
||||
) * 1.05
|
||||
|
||||
ax.set_xlim(-extent, extent)
|
||||
ax.set_ylim(-extent, extent)
|
||||
|
||||
# Hide borders
|
||||
for direction in ['left', 'bottom', 'right', 'top']:
|
||||
ax.axis[direction].set_visible(False)
|
||||
|
||||
# Show zero-centered axes
|
||||
for direction in ['xzero', 'yzero']:
|
||||
axis = ax.axis[direction]
|
||||
axis.set_visible(True)
|
||||
axis.set_axisline_style("-|>")
|
||||
|
||||
# Configure axes ticks and labels
|
||||
ax.annotate(
|
||||
'I', (1, 0.5), xycoords='axes fraction',
|
||||
xytext=(25, 0), textcoords='offset points',
|
||||
va='center', ha='right'
|
||||
)
|
||||
|
||||
ax.axis['xzero'].major_ticklabels.set_backgroundcolor('white')
|
||||
ax.axis['xzero'].major_ticklabels.set_ha('center')
|
||||
ax.axis['xzero'].major_ticklabels.set_va('top')
|
||||
|
||||
ax.annotate(
|
||||
'Q', (0.5, 1), xycoords='axes fraction',
|
||||
xytext=(0, 25), textcoords='offset points',
|
||||
va='center', ha='center'
|
||||
)
|
||||
|
||||
ax.axis['yzero'].major_ticklabels.set_rotation(-90)
|
||||
ax.axis['yzero'].major_ticklabels.set_backgroundcolor('white')
|
||||
ax.axis['yzero'].major_ticklabels.set_ha('left')
|
||||
ax.axis['yzero'].major_ticklabels.set_va('center')
|
||||
|
||||
# Add a single tick on 0
|
||||
ax.set_xticks(ax.get_xticks()[ax.get_xticks() != 0])
|
||||
ax.set_yticks(ax.get_yticks()[ax.get_yticks() != 0])
|
||||
|
||||
ax.annotate(
|
||||
'0', (0, 0),
|
||||
xytext=(15, -10), textcoords='offset points',
|
||||
va='center', ha='center'
|
||||
)
|
||||
|
||||
ax.grid()
|
||||
|
||||
# Plot encoded vectors
|
||||
ax.scatter(*zip(*encoded_vectors.tolist()), zorder=10)
|
||||
|
||||
# Add index label for each vector
|
||||
for row in range(order):
|
||||
ax.annotate(
|
||||
row + 1, encoded_vectors[row],
|
||||
xytext=(5, 5), textcoords='offset points',
|
||||
backgroundcolor='white', zorder=9
|
||||
)
|
||||
|
||||
fig, axis = pyplot.subplots()
|
||||
axis.scatter(*zip(*encoded_vectors))
|
||||
pyplot.show()
|
||||
|
|
27
train.py
27
train.py
|
@ -15,11 +15,14 @@ num_epochs = 20000
|
|||
loss_report_epoch_skip = 500
|
||||
|
||||
# File in which the trained model is saved
|
||||
output_file = 'output/constellation-net.tc'
|
||||
|
||||
print('Starting training with {} epochs\n'.format(num_epochs))
|
||||
output_file = 'output/constellation-order-{}.tc'.format(order)
|
||||
|
||||
model = constellation.ConstellationNet(order=order)
|
||||
|
||||
# Train the model with random data
|
||||
model.train()
|
||||
print('Starting training with {} epochs\n'.format(num_epochs))
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
|
||||
|
@ -46,24 +49,22 @@ for epoch in range(num_epochs):
|
|||
print('\nFinished training\n')
|
||||
|
||||
# Print some examples of reconstruction
|
||||
with torch.no_grad():
|
||||
num_examples = 5
|
||||
model.eval()
|
||||
print('Reconstruction examples:')
|
||||
print('Input vector\t\t\tOutput vector after softmax')
|
||||
|
||||
classes_example = util.get_random_messages(num_examples, order)
|
||||
onehot_example = util.messages_to_onehot(classes_example, order)
|
||||
with torch.no_grad():
|
||||
onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
|
||||
raw_output = model(onehot_example)
|
||||
raw_output.required_grad = False
|
||||
reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)
|
||||
|
||||
print('Reconstruction examples:')
|
||||
print('Input vector\t\t\tOutput vector after softmax')
|
||||
|
||||
for example_index in range(num_examples):
|
||||
for index in range(order):
|
||||
print('{}\t\t{}'.format(
|
||||
onehot_example[example_index].tolist(),
|
||||
onehot_example[index].tolist(),
|
||||
'[{}]'.format(', '.join(
|
||||
'{:.5f}'.format(x)
|
||||
for x in reconstructed_example[example_index].tolist()
|
||||
for x in reconstructed_example[index].tolist()
|
||||
))
|
||||
))
|
||||
|
||||
|
|
Loading…
Reference in New Issue