"""Plot the constellation learned by a trained ConstellationNet.

Loads the model weights saved at ``input_file``, encodes each of the
``order`` possible messages with the model's encoder, and shows the
resulting points on a scatter plot.
"""

import argparse

import torch
from matplotlib import pyplot

import constellation
from constellation import util

# Number learned symbols
order = 4

# File in which the trained model is saved
input_file = 'output/constellation-net.tc'


def load_model(path, order):
    """Build a ConstellationNet of the given order and load its weights.

    Args:
        path: File containing the state dict saved after training.
        order: Number of learned symbols; must match the saved model.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = constellation.ConstellationNet(order=order)
    # map_location='cpu' lets a checkpoint saved on a GPU be inspected on a
    # machine without one.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(path, map_location='cpu'))
    # Inference only: disable training-time behavior in case the network
    # contains layers such as dropout or batch norm.
    model.eval()
    return model


def compute_encoded_vectors(model, order):
    """Return the encoder output for every message ``0 .. order - 1``.

    Args:
        model: A loaded ConstellationNet exposing an ``encoder`` module.
        order: Number of distinct messages to encode.

    Returns:
        A list with one coordinate list per message, in message order.
    """
    with torch.no_grad():
        onehot = util.messages_to_onehot(torch.arange(0, order), order)
        return model.encoder(onehot).tolist()


def plot_constellation(points):
    """Show *points* as a scatter plot, labelling each with its message index.

    Args:
        points: Sequence of ``(x, y)`` pairs, one per message.

    Raises:
        ValueError: If any point does not have exactly two coordinates.
    """
    # scatter(*zip(*points)) only makes sense for 2-D points: a third
    # coordinate would silently be interpreted as marker sizes.
    if any(len(point) != 2 for point in points):
        raise ValueError('constellation points must be 2-dimensional')

    _, axis = pyplot.subplots()
    axis.scatter(*zip(*points))
    for index, (x, y) in enumerate(points):
        axis.annotate(str(index), (x, y),
                      textcoords='offset points', xytext=(4, 4))
    # Equal scaling so distances between symbols are not visually distorted.
    axis.set_aspect('equal')
    pyplot.show()


def main(argv=None):
    """Parse command-line options, load the model and plot its constellation.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Plot the constellation learned by a trained model.')
    parser.add_argument('--order', type=int, default=order,
                        help='number of learned symbols')
    parser.add_argument('--input-file', default=input_file,
                        help='file in which the trained model is saved')
    args = parser.parse_args(argv)

    model = load_model(args.input_file, args.order)
    plot_constellation(compute_encoded_vectors(model, args.order))


if __name__ == '__main__':
    main()