from contextlib import ExitStack
from enum import Enum, auto
from functools import partial
from itertools import product
from multiprocessing import Pool
from subprocess import DEVNULL, PIPE, Popen, TimeoutExpired
from time import time
from typing import Iterator, List, Sequence, TextIO, Tuple

# Alphabet from which candidate scripts are built: digits, ASCII letters,
# ASCII punctuation and the space character.
chars = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '


class Status(Enum):
    """Possible correctness statuses for a program."""

    Correct = auto()  # returns the expected output and a 0 exit code
    Invalid = auto()  # has a syntax error or uses an undefined command
    WrongAnswer = auto()  # returns an unknown exit code or invalid output
    Timeout = auto()  # ran for more time than allowed by the timeout value


def check_pair(script: str, instr: str, outstr: str, timeout: float) -> Status:
    """
    Check that a Bash script outputs a given string when given an input string.

    The script is run with ``/bin/bash --restricted -c`` and is prefixed with
    a trap that kills the shell's background jobs on SIGINT, SIGTERM and EXIT.

    :param script: script to execute
    :param instr: string given to the script’s stdin
    :param outstr: expected stdout value
    :param timeout: maximum allowed time in seconds
    :returns: status indicating how the script behaved
    """
    process = Popen(
        [
            "/bin/bash",
            "--restricted",
            "-c",
            "--",
            "trap 'kill -9 $(jobs -p) && wait' SIGINT SIGTERM EXIT;\n" + script,
        ],
        stdin=PIPE,
        stdout=PIPE,
        stderr=DEVNULL,
    )

    try:
        # stderr is redirected to DEVNULL, so only stdout carries data.
        stdout, _ = process.communicate(instr.encode(), timeout)
    except TimeoutExpired:
        try:
            # SIGTERM (rather than SIGKILL) so that the trap installed above
            # gets a chance to run before the shell goes away.
            process.terminate()
            process.communicate()
        except ProcessLookupError:
            # The process already exited on its own; nothing left to reap.
            pass

        return Status.Timeout

    # Exit codes treated as "the script itself is not valid" (see the Bash
    # manual on exit statuses: usage/syntax errors, not executable, command
    # not found, invalid exit argument).
    if process.returncode in (2, 126, 127, 128):
        return Status.Invalid

    if process.returncode != 0 or stdout != outstr.encode():
        return Status.WrongAnswer

    return Status.Correct


def check_script(
    pairs: Sequence[Tuple[str, str]],
    timeout: float,
    script: str,
) -> Tuple[str, Status]:
    """
    Check that a Bash script satisfies a set of test cases.

    Test cases are tried in order and checking stops at the first one that is
    not satisfied. An empty set of test cases is trivially satisfied.

    :param pairs: input/expected output pairs
    :param timeout: maximum allowed time in seconds
    :param script: script to test
    :returns: the tested script together with the status indicating how it
        behaved (the script is echoed back so that results coming out of an
        unordered pool can be matched to their input)
    """
    for pair in pairs:
        status = check_pair(script, *pair, timeout)

        if status != Status.Correct:
            return script, status

    return script, Status.Correct


def generate_scripts(max_length: int) -> Iterator[str]:
    """
    Generate all scripts up to a given length.

    Scripts are produced by increasing length, starting with the empty
    script, and in the order of ``chars`` within a given length.

    :param max_length: maximum length to generate
    :yields: generated scripts
    """
    for length in range(max_length + 1):
        for letters in product(chars, repeat=length):
            yield "".join(letters)


def find_script(
    pairs: Sequence[Tuple[str, str]],
    max_length: int,
    processes: int,
    timeout: float,
    out_valid_prefix: str,
    out_log: TextIO,
) -> List[str]:
    """
    Find scripts that satisfy the given set of test cases.

    :param pairs: input/expected output pairs
    :param max_length: maximum script length to test
    :param processes: number of parallel processes to spawn
    :param timeout: maximum allowed time in seconds for each script run
    :param out_valid_prefix: store valid scripts (includes correct scripts and
        incorrect scripts that do not contain syntax or runtime errors) into
        files starting with this prefix; one file is created per script
        length, named with the prefix followed by that length
    :param out_log: stream to which progress logs are written
    :returns: list of matching scripts
    """
    candidates = []
    bound_check_script = partial(check_script, pairs, timeout)

    # Exact number of scripts of length 0..max_length. Summing integer powers
    # avoids the float rounding of the closed-form geometric-series formula
    # for large counts.
    num_tasks = sum(len(chars) ** length for length in range(max_length + 1))
    done_tasks = 0
    start_time = time()

    # ExitStack guarantees every per-length output file is closed, even if
    # the search is interrupted by an exception.
    with ExitStack() as stack:
        out_valid = [
            stack.enter_context(open(out_valid_prefix + str(length), "w"))
            for length in range(max_length + 1)
        ]

        with Pool(processes) as pool:
            results = pool.imap_unordered(
                bound_check_script,
                generate_scripts(max_length),
                chunksize=10,
            )

            for script, status in results:
                done_tasks += 1

                if done_tasks % 10000 == 0:
                    print(
                        f"Progress: {done_tasks}/{num_tasks} "
                        f"{done_tasks / num_tasks * 100:.1f}% "
                        f"(running for {time() - start_time:.1f}s)",
                        file=out_log,
                        flush=True,
                    )

                if status == Status.Correct:
                    print(
                        f"> Found candidate: '{script}'",
                        file=out_log,
                        flush=True,
                    )
                    candidates.append(script)

                if status != Status.Invalid:
                    # Valid scripts are bucketed by length, one file each.
                    print(script, file=out_valid[len(script)], flush=True)

    print(f"Finished in {time() - start_time:.1f}s", file=out_log, flush=True)
    return candidates