Simplify arguments for ConstellationNet
This commit is contained in:
		
							parent
							
								
									a7e9dd2230
								
							
						
					
					
						commit
						d8a140d793
					
				|  | @ -13,9 +13,8 @@ class ConstellationNet(nn.Module): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         order=2, |         order=2, | ||||||
|         encoder_layers_sizes=(), |         encoder_layers=(), | ||||||
|         decoder_layers_sizes=(), |         decoder_layers=(), | ||||||
|         channel_model=GaussianChannel() |  | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Create an autoencoder. |         Create an autoencoder. | ||||||
|  | @ -23,56 +22,54 @@ class ConstellationNet(nn.Module): | ||||||
|         :param order: Order of the constellation, i.e. the number of messages |         :param order: Order of the constellation, i.e. the number of messages | ||||||
|         that are to be transmitted, or equivalently the number of symbols whose |         that are to be transmitted, or equivalently the number of symbols whose | ||||||
|         placements in the constellation have to be learned. |         placements in the constellation have to be learned. | ||||||
|         :param encoder_layers_sizes: Shape of the encoder’s hidden layers. The |         :param encoder_layers: Shape of the encoder’s hidden layers. The | ||||||
|         size of this sequence is the number of hidden layers, with each element |         size of this sequence is the number of hidden layers, with each element | ||||||
|         being a number which specifies the number of neurons in its channel. |         being a number which specifies the number of neurons in its channel. | ||||||
|         :param decoder_layers_sizes: Shape of the decoder’s hidden layers. Uses |         :param decoder_layers: Shape of the decoder’s hidden layers. Uses | ||||||
|         the same convention as `encoder_layers_sizes` above. |         the same convention as `encoder_layers_sizes` above. | ||||||
|         :param channel_model: Instance of the channel model to use between the |  | ||||||
|         encoder and decoder network. |  | ||||||
|         """ |         """ | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.order = order |         self.order = order | ||||||
| 
 | 
 | ||||||
|         # Build the encoder network taking a one-hot encoded message as input |         # Build the encoder network taking a one-hot encoded message as input | ||||||
|         # and outputting an I/Q vector. The network additionally uses hidden |         # and outputting an I/Q vector. The network additionally uses hidden | ||||||
|         # layers as specified in `encoder_layers_sizes` |         # layers as specified in `encoder_layers` | ||||||
|         prev_layer_size = order |         prev_layer_size = order | ||||||
|         encoder_layers = [] |         encoder_layers_list = [] | ||||||
| 
 | 
 | ||||||
|         for layer_size in encoder_layers_sizes: |         for layer_size in encoder_layers: | ||||||
|             encoder_layers.append(nn.Linear(prev_layer_size, layer_size)) |             encoder_layers_list.append(nn.Linear(prev_layer_size, layer_size)) | ||||||
|             encoder_layers.append(nn.ReLU()) |             encoder_layers_list.append(nn.ReLU()) | ||||||
|             encoder_layers.append(nn.BatchNorm1d(layer_size)) |             encoder_layers_list.append(nn.BatchNorm1d(layer_size)) | ||||||
|             prev_layer_size = layer_size |             prev_layer_size = layer_size | ||||||
| 
 | 
 | ||||||
|         encoder_layers += [ |         encoder_layers_list += [ | ||||||
|             nn.Linear(prev_layer_size, 2), |             nn.Linear(prev_layer_size, 2), | ||||||
|             NormalizePower(), |             NormalizePower(), | ||||||
|         ] |         ] | ||||||
| 
 | 
 | ||||||
|         self.encoder = nn.Sequential(*encoder_layers) |         self.encoder = nn.Sequential(*encoder_layers_list) | ||||||
|         self.channel = channel_model |         self.channel = GaussianChannel() | ||||||
| 
 | 
 | ||||||
|         # Build the decoder network taking the noisy I/Q vector received from |         # Build the decoder network taking the noisy I/Q vector received from | ||||||
|         # the channel as input and outputting a probability vector for each |         # the channel as input and outputting a probability vector for each | ||||||
|         # original message. The network additionally uses hidden layers as |         # original message. The network additionally uses hidden layers as | ||||||
|         # specified in `decoder_layers_sizes` |         # specified in `decoder_layers` | ||||||
|         prev_layer_size = 2 |         prev_layer_size = 2 | ||||||
|         decoder_layers = [] |         decoder_layers_list = [] | ||||||
| 
 | 
 | ||||||
|         for layer_size in decoder_layers_sizes: |         for layer_size in decoder_layers: | ||||||
|             decoder_layers.append(nn.Linear(prev_layer_size, layer_size)) |             decoder_layers_list.append(nn.Linear(prev_layer_size, layer_size)) | ||||||
|             encoder_layers.append(nn.ReLU()) |             encoder_layers_list.append(nn.ReLU()) | ||||||
|             decoder_layers.append(nn.BatchNorm1d(layer_size)) |             decoder_layers_list.append(nn.BatchNorm1d(layer_size)) | ||||||
|             prev_layer_size = layer_size |             prev_layer_size = layer_size | ||||||
| 
 | 
 | ||||||
|         # Softmax is not used at the end of the network because the |         # Softmax is not used at the end of the network because the | ||||||
|         # CrossEntropyLoss criterion is used for training, which includes |         # CrossEntropyLoss criterion is used for training, which includes | ||||||
|         # LogSoftmax |         # LogSoftmax | ||||||
|         decoder_layers.append(nn.Linear(prev_layer_size, order)) |         decoder_layers_list.append(nn.Linear(prev_layer_size, order)) | ||||||
| 
 | 
 | ||||||
|         self.decoder = nn.Sequential(*decoder_layers) |         self.decoder = nn.Sequential(*decoder_layers_list) | ||||||
| 
 | 
 | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         """ |         """ | ||||||
|  | @ -88,7 +85,7 @@ class ConstellationNet(nn.Module): | ||||||
| 
 | 
 | ||||||
|     def get_constellation(self): |     def get_constellation(self): | ||||||
|         """ |         """ | ||||||
|         Extract symbol constellation out of the trained encoder. |         Extract the symbol constellation out of the trained encoder. | ||||||
| 
 | 
 | ||||||
|         :return: Matrix containing `order` rows with the nᵗʰ one being the I/Q |         :return: Matrix containing `order` rows with the nᵗʰ one being the I/Q | ||||||
|         vector that is the result of encoding the nᵗʰ message. |         vector that is the result of encoding the nᵗʰ message. | ||||||
|  |  | ||||||
|  | @ -4,6 +4,10 @@ import math | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class GaussianChannel(nn.Module): | class GaussianChannel(nn.Module): | ||||||
|  |     """ | ||||||
|  |     Simulated communication channel that assumes a Gaussian noise model for | ||||||
|  |     taking in account interference. | ||||||
|  |     """ | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super().__init__() |         super().__init__() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								train.py
								
								
								
								
							
							
						
						
									
										2
									
								
								train.py
								
								
								
								
							|  | @ -5,7 +5,7 @@ from matplotlib import pyplot | ||||||
| from mpl_toolkits.axisartist.axislines import SubplotZero | from mpl_toolkits.axisartist.axislines import SubplotZero | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| torch.manual_seed(42) | torch.manual_seed(57) | ||||||
| 
 | 
 | ||||||
| # Number of symbols to learn | # Number of symbols to learn | ||||||
| order = 16 | order = 16 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue