37 lines
1.0 KiB
Python
37 lines
1.0 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
def get_random_messages(count, order, *, generator=None, device=None):
    """
    Generate a vector of uniformly random message indexes.

    :param count: Number of messages to generate.
    :param order: Number of possible messages; each index is drawn from
        the half-open range ``[0, order)``.
    :param generator: Optional ``torch.Generator`` used for sampling, so
        callers can obtain reproducible draws. ``None`` uses PyTorch's
        global default generator (the original behaviour).
    :param device: Optional device on which to create the result.
        ``None`` uses PyTorch's default device.
    :return: One-dimensional tensor of length ``count`` whose entries are
        message indexes between 0, inclusive, and ``order``, exclusive.
        ``torch.randint`` returns ``int64`` by default, which is the index
        dtype ``torch.nn.functional.one_hot`` expects.

    >>> get_random_messages(5, 4).shape
    torch.Size([5])
    """
    return torch.randint(0, order, (count,), generator=generator, device=device)
|
||
|
|
||
|
|
||
|
def messages_to_onehot(messages, order, *, dtype=torch.float32):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Message indexes, as an integer (``int64``) tensor or
        as a plain sequence of ints. A non-tensor sequence is converted to
        an ``int64`` tensor first, because ``one_hot`` only accepts index
        tensors. Tensors are passed through unchanged.
    :param order: Number of possible messages, i.e. the length of each
        one-hot vector. Every index must lie in ``[0, order)``.
    :param dtype: dtype of the returned encoding. Defaults to
        ``torch.float32``, matching the previous ``.float()`` behaviour.
    :return: Tensor of shape ``messages.shape + (order,)`` holding the
        one-hot encoded messages.

    >>> messages_to_onehot(torch.tensor([0, 2, 0, 3, 4]), 5)
    tensor([[1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]])
    """
    if not isinstance(messages, torch.Tensor):
        # An explicit long dtype also makes an empty list work: without it,
        # as_tensor([]) would produce a float tensor that one_hot rejects.
        messages = torch.as_tensor(messages, dtype=torch.long)
    return torch.nn.functional.one_hot(messages, num_classes=order).to(dtype)
|