diff --git a/constellation/ConstellationNet.py b/constellation/ConstellationNet.py
index 2e9d93a..39fe417 100644
--- a/constellation/ConstellationNet.py
+++ b/constellation/ConstellationNet.py
@@ -1,6 +1,8 @@
 import torch.nn as nn
+import torch
 from .GaussianChannel import GaussianChannel
 from .NormalizePower import NormalizePower
+from . import util
 
 
 class ConstellationNet(nn.Module):
@@ -30,6 +32,7 @@ class ConstellationNet(nn.Module):
         encoder and decoder network.
         """
         super().__init__()
+        self.order = order
 
         # Build the encoder network taking a one-hot encoded message as input
         # and outputting an I/Q vector. The network additionally uses hidden
@@ -78,3 +81,18 @@ class ConstellationNet(nn.Module):
         symbol = self.encoder(x)
         noisy_symbol = self.channel(symbol)
         return self.decoder(noisy_symbol)
+
+    def get_constellation(self):
+        """
+        Extract the symbol constellation out of the trained encoder.
+
+        :return: Matrix with `order` rows, the nᵗʰ row being the I/Q vector
+        obtained by encoding the nᵗʰ message.
+        """
+        with torch.no_grad():
+            return self.encoder(
+                util.messages_to_onehot(
+                    torch.arange(0, self.order),
+                    self.order
+                )
+            )
diff --git a/constellation/util.py b/constellation/util.py
index 3b4197f..4aabf06 100644
--- a/constellation/util.py
+++ b/constellation/util.py
@@ -1,4 +1,6 @@
 import torch
+import matplotlib.cm
+import matplotlib.colors
 
 
 def get_random_messages(count, order):
@@ -34,3 +36,68 @@ def messages_to_onehot(messages, order):
     ])
     """
     return torch.nn.functional.one_hot(messages, num_classes=order).float()
+
+
+def plot_constellation(ax, constellation):
+    """
+    Plot a constellation diagram on a Matplotlib axis.
+
+    :param ax: Axis to plot onto. Must be an axisartist axis such as
+    SubplotZero, because the zero-centered spines are accessed via `ax.axis`.
+    :param constellation: Matrix of I/Q vectors, as returned by
+    `ConstellationNet.get_constellation`.
+    """
+    ax.grid()
+
+    order = len(constellation)
+    color_map = matplotlib.cm.Dark2
+    color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
+
+    # Extend axes symmetrically around zero so that they fit the data
+    axis_extent = max(
+        abs(constellation.min()),
+        abs(constellation.max())
+    ) * 1.05
+    ax.set_xlim(-axis_extent, axis_extent)
+    ax.set_ylim(-axis_extent, axis_extent)
+
+    # Hide borders but keep ticks
+    for direction in ['left', 'bottom', 'right', 'top']:
+        ax.axis[direction].line.set_color('#00000000')
+
+    # Show zero-centered axes without ticks
+    for direction in ['xzero', 'yzero']:
+        axis = ax.axis[direction]
+        axis.set_visible(True)
+        axis.set_axisline_style('-|>')
+        axis.major_ticklabels.set_visible(False)
+
+    # Add axis names
+    ax.annotate(
+        'I', (1, 0.5), xycoords='axes fraction',
+        xytext=(25, 0), textcoords='offset points',
+        va='center', ha='right'
+    )
+
+    ax.annotate(
+        'Q', (0.5, 1), xycoords='axes fraction',
+        xytext=(0, 25), textcoords='offset points',
+        va='center', ha='center'
+    )
+
+    ax.scatter(
+        *zip(*constellation.tolist()),
+        zorder=10,
+        s=60,
+        c=range(len(constellation)),
+        edgecolor='black',
+        cmap=color_map,
+        norm=color_norm,
+    )
+
+    # Plot the constellation's center of mass
+    center = constellation.sum(dim=0) / order
+    ax.scatter(
+        center[0], center[1],
+        marker='X',
+    )
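Note: the two helpers added above are meant to compose. The sketch below is not part of the diff and reuses only names introduced in it; the model here is freshly initialized, not the trained one saved to output_file. Because plot_constellation reaches into ax.axis['xzero'] and ax.axis['yzero'], it needs an axisartist axis such as SubplotZero rather than a plain pyplot subplot:

    import constellation
    from constellation import util
    from matplotlib import pyplot
    from mpl_toolkits.axisartist.axislines import SubplotZero

    # Build a network; get_constellation() encodes all `order` messages at
    # once and returns an order-by-2 matrix of I/Q points
    model = constellation.ConstellationNet(
        order=4,
        encoder_layers_sizes=(8,),
        decoder_layers_sizes=(8,),
        channel_model=constellation.GaussianChannel()
    )

    fig = pyplot.figure()
    ax = SubplotZero(fig, 111)
    fig.add_subplot(ax)
    util.plot_constellation(ax, model.get_constellation())
    pyplot.show()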
diff --git a/train.py b/train.py
index 838f70d..8dec099 100644
--- a/train.py
+++ b/train.py
@@ -1,60 +1,114 @@
 import constellation
 from constellation import util
 import torch
+from matplotlib import pyplot
+import matplotlib
+from mpl_toolkits.axisartist.axislines import SubplotZero
+import time
 
 # Number of symbols to learn
 order = 4
 
-# Number of epochs
-num_epochs = 10000
+# Number of batches to skip between every loss report
+loss_report_batch_skip = 500
 
-# Number of epochs to skip between every loss report
-loss_report_epoch_skip = 500
+# Size of batches during coarse optimization (small batches)
+coarse_batch_size = 8
+
+# Size of batches during fine optimization (large batches)
+fine_batch_size = 2048
 
 # File in which the trained model is saved
 output_file = 'output/constellation-order-{}.pth'.format(order)
 
+###
+
+# Set up the plot used to show training progress
+fig = pyplot.figure()
+ax = SubplotZero(fig, 111)
+fig.add_subplot(ax)
+
+pyplot.show(block=False)
+
+# Train the model with random data
 model = constellation.ConstellationNet(
     order=order,
-    encoder_layers_sizes=(4,),
-    decoder_layers_sizes=(4,),
+    encoder_layers_sizes=(8,),
+    decoder_layers_sizes=(8,),
     channel_model=constellation.GaussianChannel()
 )
 
-# Train the model with random data
-model.train()
-print('Starting training with {} epochs\n'.format(num_epochs))
+print('Starting training\n')
 
 criterion = torch.nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters())
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
+# Loss accumulated over the batches since the last report
 running_loss = 0
-prev_epoch_size_multiple = 0
 
-for epoch in range(num_epochs):
-    epoch_size_multiple = 8 if epoch < num_epochs / 2 else 2048
+# True in the first training phase where small batches are used, and false in
+# the second phase where point positions are refined using large batches
+is_coarse_optim = True
 
-    if epoch_size_multiple != prev_epoch_size_multiple:
-        classes_ordered = torch.arange(order).repeat(epoch_size_multiple)
-        prev_epoch_size_multiple = epoch_size_multiple
+# Current batch index
+batch = 1
 
+# Current batch size
+batch_size = coarse_batch_size
+
+# List of training examples (not shuffled)
+classes_ordered = torch.arange(order).repeat(batch_size)
+
+# Constellation from the previous training batch
+prev_constellation = model.get_constellation()
+total_change = float('inf')
+
+while True:
+    # Shuffle training data and convert to one-hot encoding
     classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
     onehot_dataset = util.messages_to_onehot(classes_dataset, order)
 
+    # Perform training step for current batch
+    model.train()
     optimizer.zero_grad()
     predictions = model(onehot_dataset)
     loss = criterion(predictions, classes_dataset)
     loss.backward()
     optimizer.step()
 
-    # Report loss
+    # Check for convergence (cur_constellation avoids shadowing the
+    # `constellation` module imported above)
+    model.eval()
+    cur_constellation = model.get_constellation()
+    total_change = (cur_constellation - prev_constellation).abs().sum()
+    prev_constellation = cur_constellation
+
+    if is_coarse_optim:
+        if total_change < 1e-5:
+            print('Changing to fine optimization')
+            is_coarse_optim = False
+            batch_size = fine_batch_size
+            classes_ordered = torch.arange(order).repeat(batch_size)
+    elif total_change < 1e-5:
+        break
+
+    # Report loss and update figure (if applicable)
     running_loss += loss.item()
 
-    if epoch % loss_report_epoch_skip == loss_report_epoch_skip - 1:
-        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
-        print('Loss is {}'.format(running_loss / loss_report_epoch_skip))
+    if batch % loss_report_batch_skip == 0:
+        ax.clear()
+        util.plot_constellation(ax, cur_constellation)
+        fig.canvas.draw()
+        pyplot.pause(1e-17)  # let the GUI event loop redraw the figure
+        time.sleep(0.1)
+
+        print('Batch #{} (size {})'.format(batch, batch_size))
+        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
+        print('\tChange is {}\n'.format(total_change))
+
         running_loss = 0
 
+    batch += 1
+
 print('\nFinished training\n')
 
 # Print some examples of reconstruction
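Note: the new training loop stops on constellation movement rather than on the loss itself. total_change is the summed absolute coordinate change of all constellation points between consecutive batches, and the same 1e-5 threshold first triggers the switch from coarse (batch size 8) to fine (batch size 2048) optimization, then ends training. A minimal illustration of the metric, with made-up values:

    import torch

    prev = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    cur = torch.tensor([[1.0, 0.1], [-0.1, 1.0]])

    # Summed absolute coordinate change across all points (L1 distance)
    total_change = (cur - prev).abs().sum()
    print(total_change)  # tensor(0.2000)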