Change training strategy to adaptive learning rate
This commit is contained in:
parent
3f2c6d18a3
commit
8fa6b46ca8
8
plot.py
8
plot.py
|
@ -6,7 +6,7 @@ import matplotlib
|
||||||
from mpl_toolkits.axisartist.axislines import SubplotZero
|
from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||||
|
|
||||||
# Number learned symbols
|
# Number learned symbols
|
||||||
order = 4
|
order = 16
|
||||||
|
|
||||||
# File in which the trained model is saved
|
# File in which the trained model is saved
|
||||||
input_file = 'output/constellation-order-{}.pth'.format(order)
|
input_file = 'output/constellation-order-{}.pth'.format(order)
|
||||||
|
@ -17,8 +17,8 @@ color_map = matplotlib.cm.Dark2
|
||||||
# Restore model from file
|
# Restore model from file
|
||||||
model = constellation.ConstellationNet(
|
model = constellation.ConstellationNet(
|
||||||
order=order,
|
order=order,
|
||||||
encoder_layers_sizes=(8,),
|
encoder_layers_sizes=(8, 4),
|
||||||
decoder_layers_sizes=(8,),
|
decoder_layers_sizes=(4, 8),
|
||||||
channel_model=constellation.GaussianChannel()
|
channel_model=constellation.GaussianChannel()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ constellation = model.get_constellation()
|
||||||
util.plot_constellation(
|
util.plot_constellation(
|
||||||
ax, constellation,
|
ax, constellation,
|
||||||
model.channel, model.decoder,
|
model.channel, model.decoder,
|
||||||
grid_step=0.001, noise_samples=2500
|
grid_step=0.001, noise_samples=0
|
||||||
)
|
)
|
||||||
|
|
||||||
pyplot.show()
|
pyplot.show()
|
||||||
|
|
64
train.py
64
train.py
|
@ -7,16 +7,13 @@ from mpl_toolkits.axisartist.axislines import SubplotZero
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
# Number of symbols to learn
|
# Number of symbols to learn
|
||||||
order = 4
|
order = 16
|
||||||
|
|
||||||
# Number of batches to skip between every loss report
|
# Number of batches to skip between every loss report
|
||||||
loss_report_batch_skip = 500
|
loss_report_batch_skip = 50
|
||||||
|
|
||||||
# Size of batches during coarse optimization (small batches)
|
# Size of batches
|
||||||
coarse_batch_size = 8
|
batch_size = 32
|
||||||
|
|
||||||
# Size of batches during fine optimization (large batches)
|
|
||||||
fine_batch_size = 2048
|
|
||||||
|
|
||||||
# File in which the trained model is saved
|
# File in which the trained model is saved
|
||||||
output_file = 'output/constellation-order-{}.pth'.format(order)
|
output_file = 'output/constellation-order-{}.pth'.format(order)
|
||||||
|
@ -33,15 +30,15 @@ pyplot.show(block=False)
|
||||||
# Train the model with random data
|
# Train the model with random data
|
||||||
model = constellation.ConstellationNet(
|
model = constellation.ConstellationNet(
|
||||||
order=order,
|
order=order,
|
||||||
encoder_layers_sizes=(8,),
|
encoder_layers_sizes=(8, 4),
|
||||||
decoder_layers_sizes=(8,),
|
decoder_layers_sizes=(4, 8),
|
||||||
channel_model=constellation.GaussianChannel()
|
channel_model=constellation.GaussianChannel()
|
||||||
)
|
)
|
||||||
|
|
||||||
print('Starting training\n')
|
print('Starting training\n')
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
# Current batch index
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
batch = 0
|
||||||
|
|
||||||
# Accumulated loss for last batches
|
# Accumulated loss for last batches
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
|
@ -50,12 +47,6 @@ running_loss = 0
|
||||||
# the second phase where point positions are refined using large batches
|
# the second phase where point positions are refined using large batches
|
||||||
is_coarse_optim = True
|
is_coarse_optim = True
|
||||||
|
|
||||||
# Current batch index
|
|
||||||
batch = 1
|
|
||||||
|
|
||||||
# Current batch size
|
|
||||||
batch_size = coarse_batch_size
|
|
||||||
|
|
||||||
# List of training examples (not shuffled)
|
# List of training examples (not shuffled)
|
||||||
classes_ordered = torch.arange(order).repeat(batch_size)
|
classes_ordered = torch.arange(order).repeat(batch_size)
|
||||||
|
|
||||||
|
@ -63,7 +54,18 @@ classes_ordered = torch.arange(order).repeat(batch_size)
|
||||||
prev_constellation = model.get_constellation()
|
prev_constellation = model.get_constellation()
|
||||||
total_change = float('inf')
|
total_change = float('inf')
|
||||||
|
|
||||||
while True:
|
# Optimizer settings
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer, verbose=True,
|
||||||
|
factor=0.25,
|
||||||
|
patience=100,
|
||||||
|
cooldown=50,
|
||||||
|
threshold=1e-8
|
||||||
|
)
|
||||||
|
|
||||||
|
while total_change >= 1e-4:
|
||||||
# Shuffle training data and convert to one-hot encoding
|
# Shuffle training data and convert to one-hot encoding
|
||||||
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
|
||||||
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
|
||||||
|
@ -76,38 +78,34 @@ while True:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
# Update learning rate scheduler
|
||||||
|
scheduler.step(loss)
|
||||||
|
|
||||||
# Check for convergence
|
# Check for convergence
|
||||||
model.eval()
|
model.eval()
|
||||||
constellation = model.get_constellation()
|
constellation = model.get_constellation()
|
||||||
total_change = (constellation - prev_constellation).abs().sum()
|
total_change = (constellation - prev_constellation).abs().sum()
|
||||||
prev_constellation = constellation
|
prev_constellation = constellation
|
||||||
|
|
||||||
if is_coarse_optim:
|
# Report loss
|
||||||
if total_change < 1e-5:
|
|
||||||
print('Changing to fine optimization')
|
|
||||||
is_coarse_optim = False
|
|
||||||
batch_size = fine_batch_size
|
|
||||||
classes_ordered = torch.arange(order).repeat(batch_size)
|
|
||||||
elif total_change < 1e-5:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Report loss and update figure (if applicable)
|
|
||||||
running_loss += loss.item()
|
running_loss += loss.item()
|
||||||
|
|
||||||
if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
|
if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
|
||||||
print('Batch #{} (size {})'.format(batch + 1, batch_size))
|
print('Batch #{}'.format(batch + 1))
|
||||||
print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
|
print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
|
||||||
print('\tChange is {}\n'.format(total_change))
|
print('\tChange is {}\n'.format(total_change))
|
||||||
|
|
||||||
|
running_loss = 0
|
||||||
|
|
||||||
|
# Update figure with current encoding
|
||||||
ax.clear()
|
ax.clear()
|
||||||
util.plot_constellation(
|
util.plot_constellation(
|
||||||
ax, constellation,
|
ax, constellation,
|
||||||
model.channel, model.decoder
|
model.channel, model.decoder,
|
||||||
|
noise_samples=0
|
||||||
)
|
)
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
pyplot.pause(1e-17)
|
fig.canvas.flush_events()
|
||||||
|
|
||||||
running_loss = 0
|
|
||||||
|
|
||||||
batch += 1
|
batch += 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue