Add convergence criterion and constellation plot during training

parent 3c199dfc41
commit 59d7adf6bd
constellation/ConstellationNet.py
@@ -1,6 +1,8 @@
 import torch.nn as nn
+import torch
 from .GaussianChannel import GaussianChannel
 from .NormalizePower import NormalizePower
+from . import util
 
 
 class ConstellationNet(nn.Module):
@@ -30,6 +32,7 @@ class ConstellationNet(nn.Module):
         encoder and decoder network.
         """
         super().__init__()
+        self.order = order
 
         # Build the encoder network taking a one-hot encoded message as input
         # and outputting an I/Q vector. The network additionally uses hidden
@@ -78,3 +81,18 @@ class ConstellationNet(nn.Module):
         symbol = self.encoder(x)
         noisy_symbol = self.channel(symbol)
         return self.decoder(noisy_symbol)
+
+    def get_constellation(self):
+        """
+        Extract the symbol constellation out of the trained encoder.
+
+        :return: Matrix containing `order` rows, the nᵗʰ one being the I/Q
+        vector that results from encoding the nᵗʰ message.
+        """
+        with torch.no_grad():
+            return self.encoder(
+                util.messages_to_onehot(
+                    torch.arange(0, self.order),
+                    self.order
+                )
+            )
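For context, a minimal usage sketch of the new `get_constellation` helper (hypothetical, not part of the commit): it one-hot encodes every message from 0 to `order - 1` and passes the whole batch through the encoder under `torch.no_grad()`, yielding one I/Q row per message. The constructor arguments below are those used by `train.py` in this commit.

import constellation

# Build a small autoencoder for 4 symbols (same settings as train.py)
model = constellation.ConstellationNet(
    order=4,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)

# get_constellation() encodes messages 0..order-1 without tracking
# gradients and returns a tensor of shape (order, 2): one I/Q row
# per message.
points = model.get_constellation()
print(points.shape)  # torch.Size([4, 2])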
constellation/util.py
@@ -1,4 +1,5 @@
 import torch
+import matplotlib
 
 
 def get_random_messages(count, order):
@@ -34,3 +35,60 @@ def messages_to_onehot(messages, order):
     ])
     """
     return torch.nn.functional.one_hot(messages, num_classes=order).float()
+
+
+def plot_constellation(ax, constellation):
+    ax.grid()
+
+    order = len(constellation)
+    color_map = matplotlib.cm.Dark2
+    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
+
+    # Extend axes symmetrically around zero so that they fit the data
+    axis_extent = max(
+        abs(constellation.min()),
+        abs(constellation.max())
+    ) * 1.05
+    ax.set_xlim(-axis_extent, axis_extent)
+    ax.set_ylim(-axis_extent, axis_extent)
+
+    # Hide borders but keep ticks
+    for direction in ['left', 'bottom', 'right', 'top']:
+        ax.axis[direction].line.set_color('#00000000')
+
+    # Show zero-centered axes without ticks
+    for direction in ['xzero', 'yzero']:
+        axis = ax.axis[direction]
+        axis.set_visible(True)
+        axis.set_axisline_style('-|>')
+        axis.major_ticklabels.set_visible(False)
+
+    # Add axis names
+    ax.annotate(
+        'I', (1, 0.5), xycoords='axes fraction',
+        xytext=(25, 0), textcoords='offset points',
+        va='center', ha='right'
+    )
+
+    ax.annotate(
+        'Q', (0.5, 1), xycoords='axes fraction',
+        xytext=(0, 25), textcoords='offset points',
+        va='center', ha='center'
+    )
+
+    ax.scatter(
+        *zip(*constellation.tolist()),
+        zorder=10,
+        s=60,
+        c=range(len(constellation)),
+        edgecolor='black',
+        cmap=color_map,
+        norm=color_norm,
+    )
+
+    # Plot center
+    center = constellation.sum(dim=0) / order
+    ax.scatter(
+        center[0], center[1],
+        marker='X',
+    )
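A usage sketch for the new `plot_constellation` helper (hypothetical, not part of the commit). Note that the function subscripts `ax.axis['left']`, which is an `mpl_toolkits.axisartist` feature, so it must be given an axisartist axes such as `SubplotZero` rather than a plain `Axes`; importing `pyplot` also pulls in `matplotlib.cm`, which the bare `import matplotlib` in `util.py` does not guarantee on its own.

import torch
from matplotlib import pyplot  # also loads matplotlib.cm
from mpl_toolkits.axisartist.axislines import SubplotZero
from constellation import util

# plot_constellation needs an axisartist axes (ax.axis['left'] etc.)
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

# A hand-made QPSK-like constellation for illustration:
# one I/Q row per symbol
points = torch.tensor([
    [1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]
])
util.plot_constellation(ax, points)
pyplot.show()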
train.py
@@ -1,60 +1,114 @@
 import constellation
 from constellation import util
 import torch
+from matplotlib import pyplot
+import matplotlib
+from mpl_toolkits.axisartist.axislines import SubplotZero
+import time
 
 # Number of symbols to learn
 order = 4
 
-# Number of epochs
-num_epochs = 10000
+# Number of batches to skip between every loss report
+loss_report_batch_skip = 500
 
-# Number of epochs to skip between every loss report
-loss_report_epoch_skip = 500
+# Size of batches during coarse optimization (small batches)
+coarse_batch_size = 8
+
+# Size of batches during fine optimization (large batches)
+fine_batch_size = 2048
 
 # File in which the trained model is saved
 output_file = 'output/constellation-order-{}.pth'.format(order)
 
+###
+
+# Setup plot for showing training progress
+fig = pyplot.figure()
+ax = SubplotZero(fig, 111)
+fig.add_subplot(ax)
+
+pyplot.show(block=False)
+
+# Train the model with random data
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(4,),
-    decoder_layers_sizes=(4,),
+    encoder_layers_sizes=(8,),
+    decoder_layers_sizes=(8,),
     channel_model=constellation.GaussianChannel()
 )
 
-# Train the model with random data
-model.train()
-
-print('Starting training with {} epochs\n'.format(num_epochs))
+print('Starting training\n')
 
 criterion = torch.nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters())
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
+# Accumulated loss for last batches
 running_loss = 0
-prev_epoch_size_multiple = 0
 
-for epoch in range(num_epochs):
-    epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
+# True in the first training phase where small batches are used, and false in
+# the second phase where point positions are refined using large batches
+is_coarse_optim = True
 
-    if epoch_size_multiple != prev_epoch_size_multiple:
-        classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
-        prev_epoch_size_multiple = epoch_size_multiple
+# Current batch index
+batch = 1
 
+# Current batch size
+batch_size = coarse_batch_size
+
+# List of training examples (not shuffled)
+classes_ordered = torch.arange(order).repeat(batch_size)
+
+# Constellation from the previous training batch
+prev_constellation = model.get_constellation()
+total_change = float('inf')
+
+while True:
+    # Shuffle training data and convert to one-hot encoding
     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
 
+    # Perform training step for current batch
+    model.train()
     optimizer.zero_grad()
     predictions = model(onehot_dataset)
     loss = criterion(predictions, classes_dataset)
     loss.backward()
     optimizer.step()
 
-    # Report loss
+    # Check for convergence
+    model.eval()
+    constellation = model.get_constellation()
+    total_change = (constellation - prev_constellation).abs().sum()
+    prev_constellation = constellation
+
+    if is_coarse_optim:
+        if total_change < 1e-5:
+            print('Changing to fine optimization')
+            is_coarse_optim = False
+            batch_size = fine_batch_size
+            classes_ordered = torch.arange(order).repeat(batch_size)
+    elif total_change < 1e-5:
+        break
+
+    # Report loss and update figure (if applicable)
     running_loss += loss.item()
 
-    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
-        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
-        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
+    if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
+        ax.clear()
+        util.plot_constellation(ax, constellation)
+        fig.canvas.draw()
+        pyplot.pause(1e-17)
+        time.sleep(0.1)
+
+        print('Batch #{} (size {})'.format(batch + 1, batch_size))
+        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
+        print('\tChange is {}\n'.format(total_change))
+
         running_loss = 0
 
+    batch += 1
+
 print('\nFinished training\n')
 
 # Print some examples of reconstruction
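The training loop now runs until the constellation stops moving rather than for a fixed `num_epochs`: after each batch it measures the total L1 change of the constellation, switches from small coarse batches to large fine batches on the first convergence, and stops on the second. A condensed sketch of just that criterion, with hypothetical names (`train_step` is assumed to be supplied by the caller):

COARSE_BATCH_SIZE = 8    # small batches: fast, noisy exploration
FINE_BATCH_SIZE = 2048   # large batches: low-variance refinement
THRESHOLD = 1e-5         # minimum total constellation movement

def train_until_converged(model, train_step):
    """Hypothetical condensed version of the loop above; train_step
    runs one optimization step at the given batch size."""
    is_coarse = True
    batch_size = COARSE_BATCH_SIZE
    prev = model.get_constellation()

    while True:
        train_step(batch_size)

        # Total L1 distance between consecutive constellations
        curr = model.get_constellation()
        change = (curr - prev).abs().sum().item()
        prev = curr

        if change < THRESHOLD:
            if is_coarse:
                # First convergence: refine with large batches
                is_coarse = False
                batch_size = FINE_BATCH_SIZE
            else:
                # Second convergence: done
                return

One design wrinkle worth noting: inside the loop, the commit rebinds the name `constellation` (previously the imported package) to the tensor returned by `get_constellation`. This works only because the package is not referenced again after training starts; renaming the local tensor would avoid the shadowing.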