constellationnet/experiment.py

import math
import pickle
import random
import sys
import time

import torch

import constellation
from constellation import util

# Number of seconds to wait between each checkpoint
time_between_checkpoints = 10 * 60  # 10 minutes
# Format for checkpoint files
checkpoint_path = 'output/experiment-{}.pkl'
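# e.g. 'output/experiment-1618033988.pkl' for an intermediate checkpoint
# and 'output/experiment-final.pkl' for the final results (see the bottom
# of this script)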


def train_with_parameters(
order,
layer_sizes,
initial_learning_rate,
batch_size
):
"""
Report final loss after fully learning a constellation with given
parameters.
:param order: Number of symbols in the constellation.
:param layer_sizes: Shape of the encoders hidden layers. The
size of this sequence is the number of hidden layers, with each element
being a number which specifies the number of neurons in its channel. The
decoders hidden layers will be of the same shape but reversed.
:param initial_learning_rate: Initial learning rate used for the optimizer.
:param batch_size: Number of training examples for each training batch
expressed as a multiple of the constellation order.
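
    Example call (parameter values are illustrative only)::

        loss = train_with_parameters(
            order=16,
            layer_sizes=(8, 4),
            initial_learning_rate=0.01,
            batch_size=8
        )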
"""
model = constellation.ConstellationNet(
order=order,
encoder_layers=layer_sizes,
decoder_layers=layer_sizes[::-1],
)
# List of training examples (not shuffled)
classes_ordered = torch.arange(order).repeat(batch_size)
# Constellation from the previous training batch
prev_constel = model.get_constellation()
total_change = float('inf')
# Optimizer settings
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate)
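    # Divide the learning rate by 4 whenever the loss plateaus for about
    # 100 consecutive batches, then wait 50 batches before monitoring again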
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
factor=0.25,
patience=100,
cooldown=50,
threshold=1e-8
)
    # Train until the constellation points effectively stop moving
    while total_change >= 1e-4:
# Shuffle training data and convert to one-hot encoding
classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
onehot_dataset = util.messages_to_onehot(classes_dataset, order)
# Perform training step for current batch
model.train()
optimizer.zero_grad()
predictions = model(onehot_dataset)
loss = criterion(predictions, classes_dataset)
loss.backward()
optimizer.step()
# Update learning rate scheduler
scheduler.step(loss)
        # Check for convergence: measure the total distance moved by the
        # constellation points since the previous batch
        model.eval()
        cur_constel = model.get_constellation()
        total_change = (cur_constel - prev_constel).norm(dim=1).sum().item()
        prev_constel = cur_constel
    # Compute the final loss value on a large, freshly sampled batch
    with torch.no_grad():
        classes_ordered = torch.arange(order).repeat(2048)
        classes_dataset = classes_ordered[torch.randperm(len(classes_ordered))]
        onehot_dataset = util.messages_to_onehot(classes_dataset, order)
        predictions = model(onehot_dataset)
        return criterion(predictions, classes_dataset).item()


def evaluate_parameters(parameters, num_repeats=3):
"""
Run constellation training several times and keep the lowest reached loss.
:param parameters: Training parameters (see `train_with_parameters` for
documentation).
:param num_repeats: Number of runs.
:return: Lowest reached loss.
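
    Example call (parameter values are illustrative only)::

        best_loss = evaluate_parameters({
            'order': 4,
            'layer_sizes': (4,),
            'initial_learning_rate': 0.01,
            'batch_size': 8
        })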
"""
minimal_loss = float('inf')
for run_index in range(num_repeats):
current_loss = train_with_parameters(**parameters)
minimal_loss = min(minimal_loss, current_loss)
return minimal_loss


def generate_test_configurations():
"""
Generate the set of all configurations to be tested.
:yield: Configuration as a dictionary of parameters.
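
    Example of one yielded configuration (values shown for illustration)::

        {
            'order': 16,
            'initial_learning_rate': 0.01,
            'batch_size': 8,
            'layer_sizes': (16, 8)
        }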
"""
# Cartesian product of independent variables
independent_vars = util.product_dict(
order=[4, 16, 32],
        initial_learning_rate=[10 ** x for x in range(-2, 1)],  # 0.01, 0.1, 1
batch_size=[8, 2048],
)
# Add dependent variables
for current_dict in independent_vars:
for first_layer in range(0, current_dict['order'] + 1, 4):
for last_layer in range(0, first_layer + 1, 4):
                # Convert the pair of sizes into a shape tuple, dropping
                # zero-sized layers. Since last_layer never exceeds
                # first_layer, first_layer == 0 implies last_layer == 0.
                if first_layer == 0:
                    layer_sizes = ()
                elif last_layer == 0:
                    layer_sizes = (first_layer,)
                else:
                    layer_sizes = (first_layer, last_layer)
# Merge dependent variables with independent ones
yield {
**current_dict,
'layer_sizes': layer_sizes
}


def save_results(results, path):
"""
Save current results of experiment.
:param results: Dictionary containing current results.
:param path: Path to the file where results are to be saved.
"""
with open(path, 'wb') as file:
pickle.dump(results, file, pickle.HIGHEST_PROTOCOL)
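

# Reading results back is a plain pickle load. The helper below is an
# illustrative sketch; it is not used by the experiment script itself.
def load_results(path):
    """
    Load experiment results previously written by `save_results`.

    :param path: Path to a checkpoint file.
    :return: Dictionary containing the saved results.
    """
    with open(path, 'rb') as file:
        return pickle.load(file)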


# Build the list of all configurations to be tested, shuffled with a fixed
# seed so that every run splits the work identically
random.seed(42)
all_confs = list(generate_test_configurations())
random.shuffle(all_confs)
# Number of splits of the configuration list
parts_count = 1
# Current split of the configuration list
current_part = 0
if len(sys.argv) == 2:
print('Please specify which part must be evaluated.', file=sys.stderr)
sys.exit(1)
if len(sys.argv) == 3:
parts_count = int(sys.argv[1])
current_part = int(sys.argv[2]) - 1
if parts_count < 1:
print('There must be at least one part.', file=sys.stderr)
sys.exit(1)
if current_part < 0 or current_part >= parts_count:
print(
'Current part must be between 1 and the number of parts.',
file=sys.stderr
)
sys.exit(1)
# Starting/ending index of configurations to be tested
part_start_index = math.floor(current_part * len(all_confs) / parts_count)
part_end_index = math.floor((current_part + 1) * len(all_confs) / parts_count)
part_size = part_end_index - part_start_index
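# (e.g. 48 configurations split into 5 parts gives slices of sizes
# 9, 10, 9, 10 and 10)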
if parts_count == 1:
print('Evaluating the whole set of configurations')
print('Use “{} [parts_count] [current_part]” to divide it'.format(
sys.argv[0]
))
else:
    print('Evaluating part {}/{} of the set of configurations'.format(
        current_part + 1,
        parts_count
    ))
print('(indices {} to {})'.format(part_start_index, part_end_index - 1))
print()
# Current set of results
results = {}
# Last checkpoint save time
last_save_time = 0
for conf in all_confs[part_start_index:part_end_index]:
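    # Use a hashable, order-independent view of the configuration dictionary
    # as the results key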
key = tuple(sorted(conf.items()))
results[key] = evaluate_parameters(conf)
print('{}/{} configurations tested ({:.1f} %)'.format(
len(results), part_size,
100 * len(results) / part_size,
))
current_time = math.floor(time.time())
if current_time - last_save_time >= time_between_checkpoints:
current_path = checkpoint_path.format(current_time)
save_results(results, current_path)
print('Saved checkpoint to {}'.format(current_path))
last_save_time = current_time
# Save final checkpoint
output_path = checkpoint_path.format('final')
save_results(results, output_path)
print('Saved results to {}'.format(output_path))