import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero
import time

# Number of symbols to learn
order = 4

# Number of batches to skip between every loss report
loss_report_batch_skip = 500

# Size of batches during coarse optimization (small batches)
coarse_batch_size = 8

# Size of batches during fine optimization (large batches)
fine_batch_size = 2048

# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)
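# Note: torch.save() at the end of this script does not create missing
# directories, so the output/ directory is assumed to already exist.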

###

# Setup plot for showing training progress
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

pyplot.show(block=False)
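# show(block=False) keeps the figure window open without pausing the script,
# so the plot can be refreshed from inside the training loop below.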

# Train the model with random data
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)
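# ConstellationNet is an autoencoder: an encoder maps each one-hot message to
# a constellation point, the channel model adds Gaussian noise, and a decoder
# recovers the message. The tuples presumably set the hidden layer sizes of
# the encoder and decoder.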

print('Starting training\n')

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
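# CrossEntropyLoss applies log-softmax itself, so the model must output raw
# logits and the targets are plain class indices (not one-hot vectors).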

# Accumulated loss for last batches
running_loss = 0

# True in the first training phase where small batches are used, and false in
# the second phase where point positions are refined using large batches
is_coarse_optim = True

# Current batch index
batch = 1

# Current batch size
batch_size = coarse_batch_size

# List of training examples (not shuffled)
classes_ordered = torch.arange(order).repeat(batch_size)
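# With order = 4 and coarse_batch_size = 8 this is
# tensor([0, 1, 2, 3, 0, 1, 2, 3, ...]) with 4 * 8 = 32 entries, so every
# symbol appears batch_size times per batch.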

# Constellation from the previous training batch
prev_constellation = model.get_constellation()
total_change = float('inf')

while True:
    # Shuffle training data and convert to one-hot encoding
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)
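    # messages_to_onehot presumably maps message 2 (with order 4) to
    # [0., 0., 1., 0.]: one coordinate per possible symbol.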

    # Perform training step for current batch
    model.train()
    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()
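    # zero_grad() clears gradients accumulated by the previous batch;
    # backward() computes fresh ones and step() applies the Adam update.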

    # Check for convergence
    model.eval()
    # Use a distinct name so the imported constellation module is not shadowed
    current_constellation = model.get_constellation()
    total_change = (current_constellation - prev_constellation).abs().sum()
    prev_constellation = current_constellation

    if is_coarse_optim:
        if total_change < 1e-5:
            print('Changing to fine optimization')
            is_coarse_optim = False
            batch_size = fine_batch_size
            classes_ordered = torch.arange(order).repeat(batch_size)
    elif total_change < 1e-5:
        break
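    # total_change is the L1 distance between successive constellations: the
    # first time it drops below 1e-5, training switches to large batches for
    # fine-tuning; the next time it does, training stops.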

    # Report loss and update figure (if applicable)
    running_loss += loss.item()

    if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
        ax.clear()
        util.plot_constellation(ax, current_constellation)
        fig.canvas.draw()
        pyplot.pause(1e-17)
        time.sleep(0.1)
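        # pause() with a near-zero interval makes matplotlib process pending
        # GUI events so the non-blocking figure repaints, and the short sleep
        # gives the window time to finish drawing.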

        print('Batch #{} (size {})'.format(batch + 1, batch_size))
        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
        print('\tChange is {}\n'.format(total_change))

        running_loss = 0

    batch += 1

print('\nFinished training\n')

# Print some examples of reconstruction
model.eval()
print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')

with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    raw_output = model(onehot_example)
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)
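    # Inside torch.no_grad() no gradients are tracked, and softmax turns the
    # decoder's logits into a probability distribution over the `order`
    # possible messages (each row sums to 1).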

for index in range(order):
    print('{}\t\t{}'.format(
        onehot_example[index].tolist(),
        '[{}]'.format(', '.join(
            '{:.5f}'.format(x)
            for x in reconstructed_example[index].tolist()
        ))
    ))

print('\nSaving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)
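
# The saved state_dict contains only the parameters; it can be restored with
# model.load_state_dict(torch.load(output_file)) on a ConstellationNet built
# with the same arguments.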