constellationnet/train.py

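"""Train a ConstellationNet model to learn a symbol constellation.

The encoder maps one-hot messages to constellation points, the channel model
adds Gaussian noise, and the decoder learns to recover the original messages.
The constellation is plotted live during training and the trained model is
saved to disk at the end.
"""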
import constellation
from constellation import util
import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero
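# Seed the PRNG so training runs are reproducible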
torch.manual_seed(42)
# Number of symbols to learn
order = 16
# Number of batches to skip between every loss report
loss_report_batch_skip = 50
# Size of batches
batch_size = 32
# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)
###
# Setup plot for showing training progress
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
pyplot.show(block=False)
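# (block=False keeps the window open without pausing the script, so the
# training loop below can redraw it as the constellation evolves)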
# Create the network that will learn the constellation
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8, 4),
    decoder_layers_sizes=(4, 8),
    channel_model=constellation.GaussianChannel()
)
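# (encoder_layers_sizes and decoder_layers_sizes give the hidden layer sizes
# used on the encoder and decoder sides of the channel)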
print('Starting training\n')
# Current batch index
batch = 0
# Accumulated loss for last batches
running_loss = 0
# True in the first training phase where small batches are used, and False in
# the second phase where point positions are refined using large batches
# (currently unused: only the first phase is implemented below)
is_coarse_optim = True
# List of training examples (not shuffled): each of the `order` symbols
# repeated `batch_size` times
classes_ordered = torch.arange(order).repeat(batch_size)
# Constellation from the previous training batch
prev_constellation = model.get_constellation()
total_change = float('inf')
# Optimizer settings
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
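# (ReduceLROnPlateau multiplies the learning rate by `factor` once the loss
# has stopped improving for `patience` consecutive steps, one step per batch)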
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, verbose=True,
    factor=0.25,
    patience=100,
    cooldown=50,
    threshold=1e-8
)
while total_change >= 1e-4:
    # Shuffle training data and convert to one-hot encoding
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)

    # Perform training step for current batch
    model.train()
    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Update learning rate scheduler
    scheduler.step(loss)

    # Check for convergence: stop once the constellation points move less
    # than 1e-4 in total L1 distance between consecutive batches
    model.eval()
    cur_constellation = model.get_constellation()
    total_change = (cur_constellation - prev_constellation).abs().sum().item()
    prev_constellation = cur_constellation

    # Report loss
    running_loss += loss.item()

    if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
        print('Batch #{}'.format(batch + 1))
        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
        print('\tChange is {}\n'.format(total_change))
        running_loss = 0

        # Update figure with current encoding
        ax.clear()
        util.plot_constellation(
            ax, cur_constellation,
            model.channel, model.decoder,
            noise_samples=0
        )
        fig.canvas.draw()
        fig.canvas.flush_events()

    batch += 1
print('\nFinished training\n')
# Print some examples of reconstruction
model.eval()
print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')

with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    raw_output = model(onehot_example)
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)

    for index in range(order):
        print('{}\t\t{}'.format(
            onehot_example[index].tolist(),
            '[{}]'.format(', '.join(
                '{:.5f}'.format(x)
                for x in reconstructed_example[index].tolist()
            ))
        ))
print('\nSaving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)
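
# To reuse the trained network later, rebuild it with the same arguments and
# restore the saved weights; a minimal sketch, mirroring the constructor above:
#
#     model = constellation.ConstellationNet(
#         order=order,
#         encoder_layers_sizes=(8, 4),
#         decoder_layers_sizes=(4, 8),
#         channel_model=constellation.GaussianChannel()
#     )
#     model.load_state_dict(torch.load(output_file))
#     model.eval()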