constellationnet/plot.py

27 lines
597 B
Python
Raw Normal View History

2019-12-13 20:17:57 +00:00
import constellation
from constellation import util
import torch
from matplotlib import pyplot
# Plot the constellation learned by a trained ConstellationNet: every
# message index is run through the model's encoder and the resulting
# vectors are drawn as a scatter plot.

# Number of learned symbols (passed to the model as its `order`)
order = 4

# File in which the trained model is saved
input_file = 'output/constellation-net.tc'

model = constellation.ConstellationNet(order=order)

# map_location='cpu' lets a checkpoint that was written on a GPU machine be
# loaded on a CPU-only one; plotting needs no accelerator.
model.load_state_dict(torch.load(input_file, map_location='cpu'))

# Inference only: switch mode-dependent layers (dropout, batch norm, if the
# encoder has any) to evaluation mode so the plotted points are deterministic.
model.eval()

# Compute encoded vectors: one encoder output per message index in
# [0, order). Gradients are not needed for plotting, so they are disabled.
with torch.no_grad():
    encoded_vectors = model.encoder(
        util.messages_to_onehot(
            torch.arange(0, order),
            order
        )
    ).tolist()

# zip(*...) transposes the list of per-symbol vectors into one sequence per
# coordinate, which is the argument layout scatter() expects.
# NOTE(review): presumably the encoder emits 2-D points (x, y) -- confirm.
fig, axis = pyplot.subplots()
axis.scatter(*zip(*encoded_vectors))
pyplot.show()