37 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
		
		
			
		
	
	
			37 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
|  | import torch | ||
|  | 
 | ||
|  | 
 | ||
def get_random_messages(count, order, *, generator=None):
    """
    Generate a batch of uniformly random messages.

    :param count: Number of messages to generate.
    :param order: Number of possible messages; each message is an index in
        ``[0, order)``.
    :param generator: Optional ``torch.Generator`` used for sampling, so that
        draws can be made reproducible without touching the global RNG.
        ``None`` (the default) uses PyTorch's global generator, as before.
    :return: One-dimensional int64 tensor of length ``count`` whose entries
        are message indexes between 0, inclusive, and ``order``, exclusive.

    >>> msgs = get_random_messages(5, 4)
    >>> msgs.shape
    torch.Size([5])
    >>> bool(((msgs >= 0) & (msgs < 4)).all())
    True
    """
    # torch.randint's upper bound is exclusive, matching the [0, order) contract.
    return torch.randint(0, order, (count,), generator=generator)
|  | 
 | ||
|  | 
 | ||
def messages_to_onehot(messages, order):
    """
    Convert messages represented as indexes to one-hot encoding.

    :param messages: Message indexes to convert, as an integer tensor or a
        (nested) sequence of ints. Every index must lie in ``[0, order)``.
    :param order: Number of possible messages (the size of the one-hot
        dimension).
    :return: Float tensor of shape ``messages.shape + (order,)`` holding the
        one-hot encoding of each message.

    >>> messages_to_onehot(torch.tensor([0, 2, 1]), 3)
    tensor([[1., 0., 0.],
            [0., 0., 1.],
            [0., 1., 0.]])
    """
    if not torch.is_tensor(messages):
        # Plain Python sequences are accepted as documented; force int64 so
        # an empty list does not become a float tensor.
        messages = torch.as_tensor(messages, dtype=torch.long)
    elif not messages.is_floating_point():
        # one_hot only accepts int64 index tensors; widen int8/16/32 inputs.
        # Float tensors are left untouched so they still fail loudly.
        messages = messages.long()
    return torch.nn.functional.one_hot(messages, num_classes=order).float()