"""Plot the constellation learned by a trained ConstellationNet.

Loads the weights saved in ``input_file`` into a ConstellationNet of the
given ``order``, one-hot encodes every message ``0 .. order - 1``, runs
them through the model's encoder and scatter-plots the resulting vectors,
labelling each point with the message it represents.
"""

import torch
from matplotlib import pyplot

import constellation
from constellation import util

# Number learned symbols
order = 4

# File in which the trained model is saved
input_file = 'output/constellation-net.tc'

model = constellation.ConstellationNet(order=order)
# map_location lets weights saved during a GPU training run be loaded on a
# CPU-only machine; the plot does not need a GPU.
model.load_state_dict(torch.load(input_file, map_location='cpu'))
# Inference mode: makes dropout / batch-norm layers (if the network has any)
# behave deterministically instead of as during training.
model.eval()

# Compute encoded vectors: one row per message 0 .. order - 1.
with torch.no_grad():
    encoded_vectors = model.encoder(
        util.messages_to_onehot(
            torch.arange(0, order),
            order
        )
    ).tolist()

fig, axis = pyplot.subplots()
# Assumes each encoded vector has two components, used as (x, y) — the
# scatter call below relies on that.
axis.scatter(*zip(*encoded_vectors))

# Label every point with the message index it encodes.
for message, vector in enumerate(encoded_vectors):
    axis.annotate(str(message), (vector[0], vector[1]))

# Equal scaling on both axes so distances/angles between symbols are not
# distorted by matplotlib's automatic aspect ratio.
axis.set_aspect('equal', adjustable='datalim')

pyplot.show()