Use best parameters as found by experimentation
This commit is contained in:
parent
5cb087d971
commit
0769a61fcf
13
train.py
13
train.py
|
@ -4,16 +4,19 @@ import torch
|
||||||
from matplotlib import pyplot
|
from matplotlib import pyplot
|
||||||
from mpl_toolkits.axisartist.axislines import SubplotZero
|
from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(57)
|
||||||
|
|
||||||
# Number of symbols to learn
|
# Number of symbols to learn
|
||||||
order = 16
|
order = 16
|
||||||
|
|
||||||
|
# Initial value for the learning rate
|
||||||
|
initial_learning_rate = 0.1
|
||||||
|
|
||||||
# Number of batches to skip between every loss report
|
# Number of batches to skip between every loss report
|
||||||
loss_report_batch_skip = 50
|
loss_report_batch_skip = 50
|
||||||
|
|
||||||
# Size of batches
|
# Size of batches
|
||||||
batch_size = 32
|
batch_size = 2048
|
||||||
|
|
||||||
# File in which the trained model is saved
|
# File in which the trained model is saved
|
||||||
output_file = 'output/constellation-order-{}.pth'.format(order)
|
output_file = 'output/constellation-order-{}.pth'.format(order)
|
||||||
|
@ -30,8 +33,8 @@ pyplot.show(block=False)
|
||||||
# Train the model with random data
|
# Train the model with random data
|
||||||
model = constellation.ConstellationNet(
|
model = constellation.ConstellationNet(
|
||||||
order=order,
|
order=order,
|
||||||
encoder_layers_sizes=(8, 4),
|
encoder_layers_sizes=(8, 4,),
|
||||||
decoder_layers_sizes=(4, 8),
|
decoder_layers_sizes=(4, 8,),
|
||||||
channel_model=constellation.GaussianChannel()
|
channel_model=constellation.GaussianChannel()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,7 +55,7 @@ total_change = float('inf')
|
||||||
|
|
||||||
# Optimizer settings
|
# Optimizer settings
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
|
optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
optimizer, verbose=True,
|
optimizer, verbose=True,
|
||||||
factor=0.25,
|
factor=0.25,
|
||||||
|
|
Loading…
Reference in New Issue