Change training strategy to adaptive learning rate

This commit is contained in:
Mattéo Delabre 2019-12-16 02:30:05 -05:00
parent 3f2c6d18a3
commit 8fa6b46ca8
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
2 changed files with 39 additions and 41 deletions

View File

@@ -6,7 +6,7 @@ import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero from mpl_toolkits.axisartist.axislines import SubplotZero
# Number of learned symbols # Number of learned symbols
order = 4 order = 16
# File in which the trained model is saved # File in which the trained model is saved
input_file = 'output/constellation-order-{}.pth'.format(order) input_file = 'output/constellation-order-{}.pth'.format(order)
@@ -17,8 +17,8 @@ color_map = matplotlib.cm.Dark2
# Restore model from file # Restore model from file
model = constellation.ConstellationNet( model = constellation.ConstellationNet(
order=order, order=order,
encoder_layers_sizes=(8,), encoder_layers_sizes=(8, 4),
decoder_layers_sizes=(8,), decoder_layers_sizes=(4, 8),
channel_model=constellation.GaussianChannel() channel_model=constellation.GaussianChannel()
) )
@@ -34,7 +34,7 @@ constellation = model.get_constellation()
util.plot_constellation( util.plot_constellation(
ax, constellation, ax, constellation,
model.channel, model.decoder, model.channel, model.decoder,
grid_step=0.001, noise_samples=2500 grid_step=0.001, noise_samples=0
) )
pyplot.show() pyplot.show()

View File

@@ -7,16 +7,13 @@ from mpl_toolkits.axisartist.axislines import SubplotZero
torch.manual_seed(42) torch.manual_seed(42)
# Number of symbols to learn # Number of symbols to learn
order = 4 order = 16
# Number of batches to skip between every loss report # Number of batches to skip between every loss report
loss_report_batch_skip = 500 loss_report_batch_skip = 50
# Size of batches during coarse optimization (small batches) # Size of batches
coarse_batch_size = 8 batch_size = 32
# Size of batches during fine optimization (large batches)
fine_batch_size = 2048
# File in which the trained model is saved # File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order) output_file = 'output/constellation-order-{}.pth'.format(order)
@@ -33,15 +30,15 @@ pyplot.show(block=False)
# Train the model with random data # Train the model with random data
model = constellation.ConstellationNet( model = constellation.ConstellationNet(
order=order, order=order,
encoder_layers_sizes=(8,), encoder_layers_sizes=(8, 4),
decoder_layers_sizes=(8,), decoder_layers_sizes=(4, 8),
channel_model=constellation.GaussianChannel() channel_model=constellation.GaussianChannel()
) )
print('Starting training\n') print('Starting training\n')
criterion = torch.nn.CrossEntropyLoss() # Current batch index
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) batch = 0
# Accumulated loss for last batches # Accumulated loss for last batches
running_loss = 0 running_loss = 0
@@ -50,12 +47,6 @@ running_loss = 0
# the second phase where point positions are refined using large batches # the second phase where point positions are refined using large batches
is_coarse_optim = True is_coarse_optim = True
# Current batch index
batch = 1
# Current batch size
batch_size = coarse_batch_size
# List of training examples (not shuffled) # List of training examples (not shuffled)
classes_ordered = torch.arange(order).repeat(batch_size) classes_ordered = torch.arange(order).repeat(batch_size)
@@ -63,7 +54,18 @@ classes_ordered = torch.arange(order).repeat(batch_size)
prev_constellation = model.get_constellation() prev_constellation = model.get_constellation()
total_change = float('inf') total_change = float('inf')
while True: # Optimizer settings
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, verbose=True,
factor=0.25,
patience=100,
cooldown=50,
threshold=1e-8
)
while total_change >= 1e-4:
# Shuffle training data and convert to one-hot encoding # Shuffle training data and convert to one-hot encoding
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))] classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
onehot_dataset = util.messages_to_onehot(classes_dataset, order) onehot_dataset = util.messages_to_onehot(classes_dataset, order)
@@ -76,38 +78,34 @@ while True:
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Update learning rate scheduler
scheduler.step(loss)
# Check for convergence # Check for convergence
model.eval() model.eval()
constellation = model.get_constellation() constellation = model.get_constellation()
total_change = (constellation - prev_constellation).abs().sum() total_change = (constellation - prev_constellation).abs().sum()
prev_constellation = constellation prev_constellation = constellation
if is_coarse_optim: # Report loss
if total_change < 1e-5:
print('Changing to fine optimization')
is_coarse_optim = False
batch_size = fine_batch_size
classes_ordered = torch.arange(order).repeat(batch_size)
elif total_change < 1e-5:
break
# Report loss and update figure (if applicable)
running_loss += loss.item() running_loss += loss.item()
if batch % loss_report_batch_skip == loss_report_batch_skip - 1: if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
print('Batch #{} (size {})'.format(batch + 1, batch_size)) print('Batch #{}'.format(batch + 1))
print('\tLoss is {}'.format(running_loss / loss_report_batch_skip)) print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
print('\tChange is {}\n'.format(total_change)) print('\tChange is {}\n'.format(total_change))
running_loss = 0
# Update figure with current encoding
ax.clear() ax.clear()
util.plot_constellation( util.plot_constellation(
ax, constellation, ax, constellation,
model.channel, model.decoder model.channel, model.decoder,
noise_samples=0
) )
fig.canvas.draw() fig.canvas.draw()
pyplot.pause(1e-17) fig.canvas.flush_events()
running_loss = 0
batch += 1 batch += 1