|
|
@ -4,6 +4,7 @@ import math |
|
|
|
import torch |
|
|
|
import time |
|
|
|
import pickle |
|
|
|
import sys |
|
|
|
|
|
|
|
# Number of seconds to wait between each checkpoint |
|
|
|
time_between_checkpoints = 10 * 60 # 10 minutes |
|
|
@ -151,23 +152,66 @@ def save_results(results, path): |
|
|
|
pickle.dump(results, file, pickle.HIGHEST_PROTOCOL) |
|
|
|
|
|
|
|
|
|
|
|
# List of all configurations to be tested |
|
|
|
all_confs = list(generate_test_configurations()) |
|
|
|
|
|
|
|
# Number of splits of the configuration list |
|
|
|
parts_count = 1 |
|
|
|
|
|
|
|
# Current split of the configuration list |
|
|
|
current_part = 0 |
|
|
|
|
|
|
|
if len(sys.argv) == 2: |
|
|
|
print('Please specify which part must be evaluated.', file=sys.stderr) |
|
|
|
sys.exit(1) |
|
|
|
|
|
|
|
if len(sys.argv) == 3: |
|
|
|
parts_count = int(sys.argv[1]) |
|
|
|
current_part = int(sys.argv[2]) - 1 |
|
|
|
|
|
|
|
if parts_count < 1: |
|
|
|
print('There must be at least one part.', file=sys.stderr) |
|
|
|
sys.exit(1) |
|
|
|
|
|
|
|
if current_part < 0 or current_part >= parts_count: |
|
|
|
print( |
|
|
|
'Current part must be between 1 and the number of parts.', |
|
|
|
file=sys.stderr |
|
|
|
) |
|
|
|
sys.exit(1) |
|
|
|
|
|
|
|
# Starting/ending index of configurations to be tested |
|
|
|
part_start_index = math.floor(current_part * len(all_confs) / parts_count) |
|
|
|
part_end_index = math.floor((current_part + 1) * len(all_confs) / parts_count) |
|
|
|
part_size = part_end_index - part_start_index |
|
|
|
|
|
|
|
if parts_count == 1: |
|
|
|
print('Evaluating the whole set of configurations') |
|
|
|
print('Use “{} [parts_count] [current_part]” to divide it'.format( |
|
|
|
sys.argv[0] |
|
|
|
)) |
|
|
|
else: |
|
|
|
print('Evaluating part {}/{} of the set of configurations'.format( |
|
|
|
current_part, |
|
|
|
parts_count |
|
|
|
)) |
|
|
|
print('(indices {} to {})'.format(part_start_index, part_end_index - 1)) |
|
|
|
|
|
|
|
print() |
|
|
|
|
|
|
|
# Current set of results |
|
|
|
results = {} |
|
|
|
|
|
|
|
# Go through all configurations to be tested |
|
|
|
all_confs = list(generate_test_configurations()) |
|
|
|
|
|
|
|
# Last checkpoint save time |
|
|
|
last_save_time = 0 |
|
|
|
|
|
|
|
for conf in all_confs: |
|
|
|
for conf in all_confs[part_start_index:part_end_index]: |
|
|
|
key = tuple(sorted(conf.items())) |
|
|
|
results[key] = evaluate_parameters(conf) |
|
|
|
|
|
|
|
print('{}/{} configurations tested ({:.1f} %)'.format( |
|
|
|
len(results), |
|
|
|
len(all_confs), |
|
|
|
100 * len(results) / len(all_confs), |
|
|
|
len(results), part_size, |
|
|
|
100 * len(results) / part_size, |
|
|
|
)) |
|
|
|
|
|
|
|
current_time = math.floor(time.time()) |
|
|
|