import torch


# Integer dtypes that one_hot rejects (it requires int64) but that can be
# losslessly widened to int64 indices.
_NARROW_INT_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32)


def get_random_messages(count, order, *, generator=None):
    """
    Generate a vector of uniformly random message indices.

    :param count: Number of messages to generate.
    :param order: Number of possible messages.
    :param generator: Optional ``torch.Generator`` used for sampling, which
        allows reproducible output. Defaults to torch's global RNG.
    :return: One-dimensional int64 tensor of length ``count``. Each entry is
        the index of a generated message, between 0 (inclusive) and
        ``order`` (exclusive).

    >>> msgs = get_random_messages(5, 4)
    >>> msgs.shape
    torch.Size([5])
    >>> bool(((msgs >= 0) & (msgs < 4)).all())
    True
    """
    return torch.randint(0, order, (count,), generator=generator)


def messages_to_onehot(messages, order):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Message indices as a 1-D tensor or a sequence of ints.
        Each index must lie in ``[0, order)``.
    :param order: Number of possible messages (width of the encoding).
    :return: Float tensor of shape ``(len(messages), order)``. Row ``i`` is
        the one-hot encoding of ``messages[i]``.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    # Accept plain Python sequences as documented, not only tensors.
    messages = torch.as_tensor(messages)
    # one_hot only accepts int64 indices. Widen narrower integer types.
    # An empty input (which as_tensor makes float32) has no values to
    # truncate, so it is safe to cast too. Floating-point inputs are left
    # alone so one_hot still rejects them instead of silently truncating.
    if messages.numel() == 0 or messages.dtype in _NARROW_INT_DTYPES:
        messages = messages.long()
    return torch.nn.functional.one_hot(messages, num_classes=order).float()