diff --git a/constellation/ConstellationNet.py b/constellation/ConstellationNet.py index 7f984ff..6bacd07 100644 --- a/constellation/ConstellationNet.py +++ b/constellation/ConstellationNet.py @@ -36,14 +36,10 @@ class ConstellationNet(nn.Module): for layer_size in encoder_layers_sizes: encoder_layers.append(nn.Linear(prev_layer_size, layer_size)) - encoder_layers.append(nn.ReLU()) + encoder_layers.append(nn.SELU()) prev_layer_size = layer_size - encoder_layers += [ - nn.Linear(prev_layer_size, 2), - nn.ReLU(), - nn.BatchNorm1d(2), - ] + encoder_layers.append(nn.Linear(prev_layer_size, 2)) self.encoder = nn.Sequential(*encoder_layers) self.channel = channel_model @@ -57,7 +53,7 @@ class ConstellationNet(nn.Module): for layer_size in decoder_layers_sizes: decoder_layers.append(nn.Linear(prev_layer_size, layer_size)) - decoder_layers.append(nn.ReLU()) + decoder_layers.append(nn.SELU()) prev_layer_size = layer_size # Softmax is not used at the end of the network because the