Add convergence criterion and constellation plot during training

Mattéo Delabre 2019-12-15 19:42:50 -05:00
parent 3c199dfc41
commit 59d7adf6bd
Signed by: matteo
GPG Key ID: AE3FBD02DC583ABB
3 changed files with 150 additions and 20 deletions

@@ -1,6 +1,8 @@
import torch.nn as nn
import torch
from .GaussianChannel import GaussianChannel
from .NormalizePower import NormalizePower
from . import util
class ConstellationNet(nn.Module):
@@ -30,6 +32,7 @@ class ConstellationNet(nn.Module):
encoder and decoder network.
"""
super().__init__()
self.order = order
# Build the encoder network taking a one-hot encoded message as input
# and outputting an I/Q vector. The network additionally uses hidden
@@ -78,3 +81,18 @@ class ConstellationNet(nn.Module):
symbol = self.encoder(x)
noisy_symbol = self.channel(symbol)
return self.decoder(noisy_symbol)
def get_constellation(self):
"""
Extract symbol constellation out of the trained encoder.
:return: Matrix containing `order` rows with the nᵗʰ one being the I/Q
vector that is the result of encoding the nᵗʰ message.
"""
with torch.no_grad():
return self.encoder(
util.messages_to_onehot(
torch.arange(0, self.order),
self.order
)
)
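For illustration only (not part of this commit), the new `get_constellation` method can be exercised as in the sketch below; the constructor arguments are assumptions that simply mirror the training script further down.

import constellation

# Illustrative sketch: a small 4-symbol network, mirroring the training script
model = constellation.ConstellationNet(
    order=4,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)

# One I/Q row per possible message, shape (4, 2)
points = model.get_constellation()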

@@ -1,4 +1,5 @@
import torch
import matplotlib
def get_random_messages(count, order):
@@ -34,3 +35,60 @@ def messages_to_onehot(messages, order):
])
"""
return torch.nn.functional.one_hot(messages, num_classes=order).float()
def plot_constellation(ax, constellation):
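"""
Plot a constellation diagram on a pair of zero-centered I/Q axes.
:param ax: Axes to draw on; must come from `mpl_toolkits.axisartist`
(e.g. `SubplotZero`) so that the `xzero`/`yzero` axis lines used below exist.
:param constellation: Matrix of I/Q vectors such as the one returned by
`ConstellationNet.get_constellation`.
"""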
ax.grid()
order = len(constellation)
color_map = matplotlib.cm.Dark2
color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
# Extend axes symmetrically around zero so that they fit data
axis_extent = max(
abs(constellation.min()),
abs(constellation.max())
) * 1.05
ax.set_xlim(-axis_extent, axis_extent)
ax.set_ylim(-axis_extent, axis_extent)
# Hide borders but keep ticks
for direction in ['left', 'bottom', 'right', 'top']:
ax.axis[direction].line.set_color('#00000000')
# Show zero-centered axes without ticks
for direction in ['xzero', 'yzero']:
axis = ax.axis[direction]
axis.set_visible(True)
axis.set_axisline_style('-|>')
axis.major_ticklabels.set_visible(False)
# Add axis names
ax.annotate(
'I', (1, 0.5), xycoords='axes fraction',
xytext=(25, 0), textcoords='offset points',
va='center', ha='right'
)
ax.annotate(
'Q', (0.5, 1), xycoords='axes fraction',
xytext=(0, 25), textcoords='offset points',
va='center', ha='center'
)
ax.scatter(
*zip(*constellation.tolist()),
zorder=10,
s=60,
c=range(len(constellation)),
edgecolor='black',
cmap=color_map,
norm=color_norm,
)
# Plot center
center = constellation.sum(dim=0) / order
ax.scatter(
center[0], center[1],
marker='X',
)
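A minimal usage sketch (not part of the commit); the figure setup mirrors the training script below, and the random points are only placeholder data:

import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero
from constellation import util

# Illustrative only: plot four random I/Q points on zero-centered axes
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
util.plot_constellation(ax, torch.randn(4, 2))
pyplot.show()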

@@ -1,60 +1,114 @@
import constellation
from constellation import util
import torch
from matplotlib import pyplot
import matplotlib
from mpl_toolkits.axisartist.axislines import SubplotZero
import time
# Number of symbols to learn
order = 4
# Number of epochs
num_epochs = 10000
# Number of batches to skip between every loss report
loss_report_batch_skip = 500
# Number of epochs to skip between every loss report
loss_report_epoch_skip = 500
# Size of batches during coarse optimization (small batches)
coarse_batch_size = 8
# Size of batches during fine optimization (large batches)
fine_batch_size = 2048
# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)
###
# Setup plot for showing training progress
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
pyplot.show(block=False)
# Train the model with random data
model = constellation.ConstellationNet(
order=order,
encoder_layers_sizes=(4,),
decoder_layers_sizes=(4,),
encoder_layers_sizes=(8,),
decoder_layers_sizes=(8,),
channel_model=constellation.GaussianChannel()
)
# Train the model with random data
model.train()
print('Starting training with {} epochs\n'.format(num_epochs))
print('Starting training\n')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Accumulated loss for last batches
running_loss = 0
prev_epoch_size_multiple = 0
for epoch in range(num_epochs):
epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
# True in the first training phase where small batches are used, and false in
# the second phase where point positions are refined using large batches
is_coarse_optim = True
if epoch_size_multiple != prev_epoch_size_multiple:
classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
prev_epoch_size_multiple = epoch_size_multiple
# Current batch index
batch = 1
# Current batch size
batch_size = coarse_batch_size
# List of training examples (not shuffled)
classes_ordered = torch.arange(order).repeat(batch_size)
# Constellation from the previous training batch
prev_constellation = model.get_constellation()
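# Total movement of the constellation points between two consecutive
# batches, used below as the convergence criterion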
total_change = float('inf')
while True:
# Shuffle training data and convert to one-hot encoding
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
# Perform training step for current batch
model.train()
optimizer.zero_grad()
predictions = model(onehot_dataset)
loss = criterion(predictions, classes_dataset)
loss.backward()
optimizer.step()
# Report loss
# Check for convergence
model.eval()
constellation = model.get_constellation()
total_change = (constellation - prev_constellation).abs().sum()
prev_constellation = constellation
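# Two-phase stopping rule: once the points stop moving under small
# (coarse) batches, switch to large (fine) batches; once they stop
# moving again, consider training converged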
if is_coarse_optim:
if total_change < 1e-5:
print('Changing to fine optimization')
is_coarse_optim = False
batch_size = fine_batch_size
classes_ordered = torch.arange(order).repeat(batch_size)
elif total_change < 1e-5:
break
# Report loss and update figure (if applicable)
running_loss += loss.item()
if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
ax.clear()
util.plot_constellation(ax, constellation)
fig.canvas.draw()
pyplot.pause(1e-17)
time.sleep(0.1)
print('Batch #{} (size {})'.format(batch + 1, batch_size))
print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
print('\tChange is {}\n'.format(total_change))
running_loss = 0
batch += 1
print('\nFinished training\n')
# Print some examples of reconstruction
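The hunk is cut off here. Purely as an illustration of what such a reconstruction check might look like (an assumption, not code from this commit), a few random messages could be pushed through the model and compared with the decoder's most likely guess:

# Illustrative sketch only, not from this commit
model.eval()
with torch.no_grad():
    examples = util.get_random_messages(10, order)
    predictions = model(util.messages_to_onehot(examples, order))

print(examples)
print(predictions.argmax(dim=1))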