From 8d64f916619987c74ab4a96aa91290164a550ecc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Delabre?= Date: Mon, 16 Dec 2019 13:27:43 -0500 Subject: [PATCH] Add partitioning to experiment --- experiment.py | 58 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/experiment.py b/experiment.py index 0973b6a..3cfea10 100644 --- a/experiment.py +++ b/experiment.py @@ -4,6 +4,7 @@ import math import torch import time import pickle +import sys # Number of seconds to wait between each checkpoint time_between_checkpoints = 10 * 60 # 10 minutes @@ -151,23 +152,66 @@ def save_results(results, path): pickle.dump(results, file, pickle.HIGHEST_PROTOCOL) +# List of all configurations to be tested +all_confs = list(generate_test_configurations()) + +# Number of splits of the configuration list +parts_count = 1 + +# Current split of the configuration list +current_part = 0 + +if len(sys.argv) == 2: + print('Please specify which part must be evaluated.', file=sys.stderr) + sys.exit(1) + +if len(sys.argv) == 3: + parts_count = int(sys.argv[1]) + current_part = int(sys.argv[2]) - 1 + + if parts_count < 1: + print('There must be at least one part.', file=sys.stderr) + sys.exit(1) + + if current_part < 0 or current_part >= parts_count: + print( + 'Current part must be between 1 and the number of parts.', + file=sys.stderr + ) + sys.exit(1) + +# Starting/ending index of configurations to be tested +part_start_index = math.floor(current_part * len(all_confs) / parts_count) +part_end_index = math.floor((current_part + 1) * len(all_confs) / parts_count) +part_size = part_end_index - part_start_index + +if parts_count == 1: + print('Evaluating the whole set of configurations') + print('Use “{} [parts_count] [current_part]” to divide it'.format( + sys.argv[0] + )) +else: + print('Evaluating part {}/{} of the set of configurations'.format( + current_part, + parts_count + )) + print('(indices {} to {})'.format(part_start_index, part_end_index - 1)) + +print() + # Current set of results results = {} -# Go through all configurations to be tested -all_confs = list(generate_test_configurations()) - # Last checkpoint save time last_save_time = 0 -for conf in all_confs: +for conf in all_confs[part_start_index:part_end_index]: key = tuple(sorted(conf.items())) results[key] = evaluate_parameters(conf) print('{}/{} configurations tested ({:.1f} %)'.format( - len(results), - len(all_confs), - 100 * len(results) / len(all_confs), + len(results), part_size, + 100 * len(results) / part_size, )) current_time = math.floor(time.time())