Add convergence criterion and plot during training
parent 3c199dfc41
commit 59d7adf6bd
constellation/ConstellationNet.py
@@ -1,6 +1,8 @@
 import torch.nn as nn
+import torch
 from .GaussianChannel import GaussianChannel
 from .NormalizePower import NormalizePower
+from . import util
 
 
 class ConstellationNet(nn.Module):
@@ -30,6 +32,7 @@ class ConstellationNet(nn.Module):
         encoder and decoder network.
         """
         super().__init__()
+        self.order = order
 
         # Build the encoder network taking a one-hot encoded message as input
         # and outputting an I/Q vector. The network additionally uses hidden
@@ -78,3 +81,18 @@
         symbol = self.encoder(x)
         noisy_symbol = self.channel(symbol)
         return self.decoder(noisy_symbol)
+
+    def get_constellation(self):
+        """
+        Extract symbol constellation out of the trained encoder.
+
+        :return: Matrix containing `order` rows with the nᵗʰ one being the I/Q
+        vector that is the result of encoding the nᵗʰ message.
+        """
+        with torch.no_grad():
+            return self.encoder(
+                util.messages_to_onehot(
+                    torch.arange(0, self.order),
+                    self.order
+                )
+            )
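(For context: a minimal sketch of how the new get_constellation accessor can be
exercised on its own. This is not part of the commit; the constructor arguments
simply mirror the ones train.py uses below, and the printed shape assumes the
encoder outputs 2-dimensional I/Q vectors, as the docstring describes.)

import constellation

model = constellation.ConstellationNet(
    order=4,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)

# Encodes every message 0..3 in a single batch, with gradient tracking disabled
points = model.get_constellation()
print(points.shape)  # expected: torch.Size([4, 2]), one I/Q row per message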
constellation/util.py
@@ -1,4 +1,5 @@
 import torch
+import matplotlib
 
 
 def get_random_messages(count, order):
@@ -34,3 +35,60 @@ def messages_to_onehot(messages, order):
     ])
     """
     return torch.nn.functional.one_hot(messages, num_classes=order).float()
+
+
+def plot_constellation(ax, constellation):
+    ax.grid()
+
+    order = len(constellation)
+    color_map = matplotlib.cm.Dark2
+    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
+
+    # Extend axes symmetrically around zero so that they fit data
+    axis_extent = max(
+        abs(constellation.min()),
+        abs(constellation.max())
+    ) * 1.05
+    ax.set_xlim(-axis_extent, axis_extent)
+    ax.set_ylim(-axis_extent, axis_extent)
+
+    # Hide borders but keep ticks
+    for direction in ['left', 'bottom', 'right', 'top']:
+        ax.axis[direction].line.set_color('#00000000')
+
+    # Show zero-centered axes without ticks
+    for direction in ['xzero', 'yzero']:
+        axis = ax.axis[direction]
+        axis.set_visible(True)
+        axis.set_axisline_style('-|>')
+        axis.major_ticklabels.set_visible(False)
+
+    # Add axis names
+    ax.annotate(
+        'I', (1, 0.5), xycoords='axes fraction',
+        xytext=(25, 0), textcoords='offset points',
+        va='center', ha='right'
+    )
+
+    ax.annotate(
+        'Q', (0.5, 1), xycoords='axes fraction',
+        xytext=(0, 25), textcoords='offset points',
+        va='center', ha='center'
+    )
+
+    ax.scatter(
+        *zip(*constellation.tolist()),
+        zorder=10,
+        s=60,
+        c=range(len(constellation)),
+        edgecolor='black',
+        cmap=color_map,
+        norm=color_norm,
+    )
+
+    # Plot center
+    center = constellation.sum(dim=0) / order
+    ax.scatter(
+        center[0], center[1],
+        marker='X',
+    )
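(Usage note: plot_constellation configures the 'xzero' and 'yzero' axis lines,
which only exist on axisartist axes, so it must be handed a SubplotZero axes as
train.py does below, rather than a regular subplot. A minimal sketch, with an
illustrative hand-written constellation:)

import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero
from constellation import util

fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# Four illustrative I/Q points, one row per symbol (QPSK-like layout)
points = torch.tensor([[1., 1.], [-1., 1.], [-1., -1.], [1., -1.]])

util.plot_constellation(ax, points)
pyplot.show()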
94 train.py
@@ -1,60 +1,114 @@
 import constellation
 from constellation import util
 import torch
+from matplotlib import pyplot
+import matplotlib
+from mpl_toolkits.axisartist.axislines import SubplotZero
+import time
 
 # Number of symbols to learn
 order = 4
 
-# Number of epochs
-num_epochs = 10000
+# Number of batches to skip between every loss report
+loss_report_batch_skip = 500
 
-# Number of epochs to skip between every loss report
-loss_report_epoch_skip = 500
+# Size of batches during coarse optimization (small batches)
+coarse_batch_size = 8
+
+# Size of batches during fine optimization (large batches)
+fine_batch_size = 2048
 
 # File in which the trained model is saved
 output_file = 'output/constellation-order-{}.pth'.format(order)
 
 ###
 
+# Setup plot for showing training progress
+fig = pyplot.figure()
+ax = SubplotZero(fig, 111)
+fig.add_subplot(ax)
+
+pyplot.show(block=False)
+
-# Train the model with random data
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(4,),
-    decoder_layers_sizes=(4,),
+    encoder_layers_sizes=(8,),
+    decoder_layers_sizes=(8,),
     channel_model=constellation.GaussianChannel()
 )
 
+# Train the model with random data
 model.train()
-print('Starting training with {} epochs\n'.format(num_epochs))
+print('Starting training\n')
 
 criterion = torch.nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters())
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
 # Accumulated loss for last batches
 running_loss = 0
-prev_epoch_size_multiple = 0
 
-for epoch in range(num_epochs):
-    epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
+# True in the first training phase where small batches are used, and false in
+# the second phase where point positions are refined using large batches
+is_coarse_optim = True
 
-    if epoch_size_multiple != prev_epoch_size_multiple:
-        classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
-        prev_epoch_size_multiple = epoch_size_multiple
+# Current batch index
+batch = 1
 
+# Current batch size
+batch_size = coarse_batch_size
+
+# List of training examples (not shuffled)
+classes_ordered = torch.arange(order).repeat(batch_size)
+
+# Constellation from the previous training batch
+prev_constellation = model.get_constellation()
+total_change = float('inf')
+
+while True:
     # Shuffle training data and convert to one-hot encoding
     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
 
     # Perform training step for current batch
     model.train()
     optimizer.zero_grad()
     predictions = model(onehot_dataset)
     loss = criterion(predictions, classes_dataset)
     loss.backward()
     optimizer.step()
 
-    # Report loss
+    # Check for convergence
+    model.eval()
+    constellation = model.get_constellation()
+    total_change = (constellation - prev_constellation).abs().sum()
+    prev_constellation = constellation
+
+    if is_coarse_optim:
+        if total_change < 1e-5:
+            print('Changing to fine optimization')
+            is_coarse_optim = False
+            batch_size = fine_batch_size
+            classes_ordered = torch.arange(order).repeat(batch_size)
+    elif total_change < 1e-5:
+        break
+
+    # Report loss and update figure (if applicable)
     running_loss += loss.item()
 
-    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
-        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
-        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
+    if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
+        ax.clear()
+        util.plot_constellation(ax, constellation)
+        fig.canvas.draw()
+        pyplot.pause(1e-17)
+        time.sleep(0.1)
+
+        print('Batch #{} (size {})'.format(batch + 1, batch_size))
+        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
+        print('\tChange is {}\n'.format(total_change))
 
         running_loss = 0
 
+    batch += 1
+
 print('\nFinished training\n')
 
 # Print some examples of reconstruction
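(The new stopping rule replaces the fixed epoch count: after every batch the
constellation is re-extracted and its total L1 displacement since the previous
batch is compared against 1e-5. The first time the threshold is crossed, the
loop switches from coarse batches of 8 examples per symbol to fine batches of
2048; the second time, training stops. Distilled into a standalone sketch, with
a hypothetical helper name and the same threshold:)

import torch

def has_converged(prev_points, points, threshold=1e-5):
    # Total L1 displacement of all constellation points in one step
    return (points - prev_points).abs().sum().item() < threshold

# Example: 8 coordinates that each moved by 1e-3 / 8 give a total
# displacement of 1e-3, still well above the 1e-5 threshold
a = torch.zeros(4, 2)
b = torch.full((4, 2), 1e-3 / 8)
print(has_converged(a, b))  # False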