from contextlib import ExitStack
from enum import Enum, auto
from functools import partial
from itertools import product
from multiprocessing import Pool
from subprocess import DEVNULL, PIPE, Popen, TimeoutExpired
from typing import Iterator, List, Tuple

# Alphabet from which candidate scripts are built.
chars = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '


class Status(Enum):
    """Possible correctness statuses for a program."""

    Correct = auto()
    WrongAnswer = auto()
    Timeout = auto()
    Invalid = auto()


def check_pair(script, instr, outstr, timeout) -> Status:
    """
    Check that a Bash script outputs a given string when given an input string.

    :param script: script to execute
    :param instr: string given to the script's stdin
    :param outstr: expected stdout value
    :param timeout: maximum allowed time in seconds
    :returns: status indicating how the script behaved
    """
    process = Popen(
        [
            "/bin/bash",
            "--restricted",
            "-c",
            "--",
            # The trap kills any background jobs the script started as soon
            # as the shell is interrupted, terminated or exits.
            "trap 'kill -9 $(jobs -p) && wait' SIGINT SIGTERM EXIT;\n"
            + script,
        ],
        stdin=PIPE,
        stdout=PIPE,
        stderr=DEVNULL,
    )

    try:
        stdout, _ = process.communicate(instr.encode(), timeout=timeout)
    except TimeoutExpired:
        try:
            # SIGTERM (rather than SIGKILL) so that the trap installed above
            # gets a chance to clean up the script's background jobs.
            process.terminate()
            process.communicate()
        except ProcessLookupError:
            # The process already went away on its own.
            pass

        return Status.Timeout

    # Exit codes that Bash conventionally uses when the script itself is
    # malformed rather than merely wrong: syntax/usage error (2), command
    # not executable (126), command not found (127); 128 is likewise
    # treated as an invalid script here.
    if process.returncode in (2, 126, 127, 128):
        return Status.Invalid

    if process.returncode != 0 or stdout != outstr.encode():
        return Status.WrongAnswer

    return Status.Correct


def check_script(pairs, timeout, script) -> Tuple[str, Status]:
    """
    Check that a Bash script satisfies a set of test cases.

    Test cases are run in order and checking stops at the first one that is
    not satisfied. A script checked against no test cases at all is reported
    as correct.

    :param pairs: input/expected output pairs
    :param timeout: maximum allowed time in seconds
    :param script: script to test
    :returns: the script itself together with the status indicating how it
        behaved (the script is echoed back so that results coming out of an
        unordered pool can be matched to their input)
    """
    for pair in pairs:
        status = check_pair(script, *pair, timeout)

        if status != Status.Correct:
            return script, status

    return script, Status.Correct


def generate_scripts(max_length) -> Iterator[str]:
    """
    Generate all scripts up to a given length.

    Scripts are produced by increasing length, starting with the empty
    script, and in the order of ``chars`` within each length.

    :param max_length: maximum length to generate
    :yields: generated scripts
    """
    for length in range(max_length + 1):
        for letters in product(chars, repeat=length):
            yield "".join(letters)


def count_scripts(alphabet_size, max_length) -> int:
    """
    Count the scripts of length 0 to ``max_length`` over a given alphabet.

    This is the geometric sum ``1 + k + k**2 + ... + k**max_length``,
    computed with integer arithmetic only so that the result stays exact
    however large it gets.

    :param alphabet_size: number of distinct characters available
    :param max_length: maximum script length
    :returns: total number of scripts
    """
    if max_length < 0:
        return 0

    if alphabet_size == 1:
        # The closed form below would divide by zero.
        return max_length + 1

    return (alphabet_size ** (max_length + 1) - 1) // (alphabet_size - 1)


def find_script(
    pairs,
    max_length,
    processes,
    timeout,
    invalid_prefix,
    out_log,
) -> List[str]:
    """
    Find scripts that satisfy the given set of test cases.

    :param pairs: input/expected output pairs
    :param max_length: maximum script length to test
    :param processes: number of parallel processes to spawn
    :param timeout: maximum allowed time in seconds for each script run
    :param invalid_prefix: prefix to the files in which invalid scripts are
        to be stored, one file per script length
    :param out_log: stream to which progress logs are written
    :returns: list of matching scripts
    """
    candidates = []
    bound_check_script = partial(check_script, pairs, timeout)
    num_tasks = count_scripts(len(chars), max_length)
    done_tasks = 0

    # ExitStack guarantees every per-length file is closed, even if the pool
    # or one of the later open() calls raises.
    with ExitStack() as stack:
        out_invalid = [
            stack.enter_context(open(invalid_prefix + str(length), "w"))
            for length in range(max_length + 1)
        ]

        with Pool(processes) as pool:
            results = pool.imap_unordered(
                bound_check_script,
                generate_scripts(max_length),
                chunksize=10,
            )

            for script, status in results:
                done_tasks += 1

                if done_tasks % 10000 == 0:
                    percent = done_tasks / num_tasks * 100
                    print(
                        f"Progress: {done_tasks}/{num_tasks} {percent:.1f}%",
                        file=out_log,
                        flush=True,
                    )

                if status == Status.Correct:
                    print(
                        f"> Found candidate: '{script}'",
                        file=out_log,
                        flush=True,
                    )
                    candidates.append(script)
                elif status == Status.Invalid:
                    # Results arrive unordered, so pick the file by length.
                    print(script, file=out_invalid[len(script)])

    return candidates