import constellation
from constellation import util
import torch
from matplotlib import pyplot
from mpl_toolkits.axisartist.axislines import SubplotZero

torch.manual_seed(42)

# Number of symbols to learn
order = 4

# Number of batches to skip between every loss report
loss_report_batch_skip = 500

# Size of batches during coarse optimization (small batches)
coarse_batch_size = 8

# Size of batches during fine optimization (large batches)
fine_batch_size = 2048

# File in which the trained model is saved
output_file = 'output/constellation-order-{}.pth'.format(order)

###

# Set up plot for showing training progress
fig = pyplot.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)

pyplot.show(block=False)

# Train the model with random data
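# As its parameter names suggest, ConstellationNet behaves like an
# autoencoder: the encoder maps one-hot messages to constellation points,
# the Gaussian channel model perturbs them, and the decoder maps the noisy
# points back to per-symbol scores (the (8,) tuples presumably give the
# hidden layer sizes of the encoder and decoder)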
model = constellation.ConstellationNet(
    order=order,
    encoder_layers_sizes=(8,),
    decoder_layers_sizes=(8,),
    channel_model=constellation.GaussianChannel()
)

print('Starting training\n')

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
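# CrossEntropyLoss applies log-softmax internally and expects raw decoder
# scores together with integer class labels, so no softmax is needed during
# training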
# Accumulated loss for last batches
running_loss = 0

# True in the first training phase where small batches are used, and false in
# the second phase where point positions are refined using large batches
is_coarse_optim = True
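# (Presumably the small coarse batches give noisier gradients that help the
# constellation move around early on, while the large fine batches average
# out the channel noise for a more precise final placement)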
# Current batch index
batch = 1

# Current batch size
batch_size = coarse_batch_size

# List of training examples (not shuffled)
classes_ordered = torch.arange(order).repeat(batch_size)
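# Each batch therefore contains batch_size copies of every one of the `order`
# messages; they are reshuffled at the start of each iteration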
# Constellation from the previous training batch
prev_constellation = model.get_constellation()
total_change = float('inf')

while True:
    # Shuffle training data and convert to one-hot encoding
    classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
    onehot_dataset = util.messages_to_onehot(classes_dataset, order)

    # Perform training step for current batch
    model.train()
    optimizer.zero_grad()
    predictions = model(onehot_dataset)
    loss = criterion(predictions, classes_dataset)
    loss.backward()
    optimizer.step()

    # Check for convergence (constellation_points avoids shadowing the
    # imported constellation module)
    model.eval()
    constellation_points = model.get_constellation()
    total_change = (constellation_points - prev_constellation).abs().sum()
    prev_constellation = constellation_points
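    # The change metric is the summed absolute displacement of all
    # constellation points since the previous batch; each training phase
    # ends once it drops below 1e-5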
    if is_coarse_optim:
        if total_change < 1e-5:
            print('Changing to fine optimization')
            is_coarse_optim = False
            batch_size = fine_batch_size
            classes_ordered = torch.arange(order).repeat(batch_size)
    elif total_change < 1e-5:
        break
    # Report loss and update figure (if applicable)
    running_loss += loss.item()

    if batch % loss_report_batch_skip == loss_report_batch_skip - 1:
        print('Batch #{} (size {})'.format(batch + 1, batch_size))
        print('\tLoss is {}'.format(running_loss / loss_report_batch_skip))
        print('\tChange is {}\n'.format(total_change))

        ax.clear()
        util.plot_constellation(
            ax, constellation_points,
            model.channel, model.decoder
        )
        fig.canvas.draw()
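        # A near-zero pause is enough to make Matplotlib process pending GUI
        # events so the non-blocking figure actually refreshes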
        pyplot.pause(1e-17)
        running_loss = 0

    batch += 1

print('\nFinished training\n')

# Print some examples of reconstruction
model.eval()
print('Reconstruction examples:')
print('Input vector\t\t\tOutput vector after softmax')
with torch.no_grad():
    onehot_example = util.messages_to_onehot(torch.arange(0, order), order)
    raw_output = model(onehot_example)
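    # The model outputs raw scores, so apply softmax to display each
    # reconstruction as a probability distribution over the symbols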
    reconstructed_example = torch.nn.functional.softmax(raw_output, dim=1)
    for index in range(order):
        print('{}\t\t{}'.format(
            onehot_example[index].tolist(),
            '[{}]'.format(', '.join(
                '{:.5f}'.format(x)
                for x in reconstructed_example[index].tolist()
            ))
        ))

print('\nSaving model as {}'.format(output_file))
torch.save(model.state_dict(), output_file)
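
# To reuse the trained weights later, rebuild the network with the same
# hyperparameters and restore the saved state dict, for example:
#
#     model = constellation.ConstellationNet(
#         order=order,
#         encoder_layers_sizes=(8,),
#         decoder_layers_sizes=(8,),
#         channel_model=constellation.GaussianChannel()
#     )
#     model.load_state_dict(torch.load(output_file))
#     model.eval()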