diff options
-rw-r--r-- | config/snekbox.cfg | 1 | ||||
-rw-r--r-- | snekbox/nsjail.py | 33 | ||||
-rw-r--r-- | tests/test_nsjail.py | 9 |
3 files changed, 38 insertions, 5 deletions
diff --git a/config/snekbox.cfg b/config/snekbox.cfg index 5b47459..27caf27 100644 --- a/config/snekbox.cfg +++ b/config/snekbox.cfg @@ -18,7 +18,6 @@ envar: "NUMEXPR_NUM_THREADS=1" keep_caps: false rlimit_as: 700 -rlimit_fsize: 10 clone_newnet: true clone_newuser: true diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index cafde6d..b792ba9 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -27,6 +27,9 @@ NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail") NSJAIL_CFG = os.getenv("NSJAIL_CFG", "./config/snekbox.cfg") MEM_MAX = 52428800 +# Limit of stdout bytes we consume before terminating nsjail +OUTPUT_MAX = 1_000_000 # 1 MB + class NsJail: """ @@ -124,7 +127,7 @@ class NsJail: log.info(msg) try: - result = subprocess.run( + nsjail = subprocess.Popen( args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -133,11 +136,33 @@ class NsJail: except ValueError: return CompletedProcess(args, None, "ValueError: embedded null byte", None) + output_size = 0 + output = [] + + # We'll consume STDOUT as long as the subprocess is running + while nsjail.poll() is None: + # Read 100 characters from the STDOUT stream + chars = nsjail.stdout.read(100) + chars_size = sys.getsizeof(chars) + + # Check if these characters take us over the output limit + if output_size + chars_size > OUTPUT_MAX: + # Ask nsjail to terminate itself using SIGTERM + nsjail.terminate() + break + + output_size += chars_size + output.append(chars) + + # Ensure that we wait for the nsjail process to terminate + nsjail.wait() + log_lines = nsj_log.read().decode("utf-8").splitlines() - if not log_lines and result.returncode == 255: + if not log_lines and nsjail.returncode == 255: # NsJail probably failed to parse arguments so log output will still be in stdout - log_lines = result.stdout.splitlines() + log_lines = "".join(output).splitlines() self._parse_log(log_lines) - return result + log.info(f"nsjail return code: {nsjail.returncode}") + return CompletedProcess(args, nsjail.returncode, "".join(output), None) diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 0b755b2..852be4b 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -174,3 +174,12 @@ class NsJailTests(unittest.TestCase): msg="stdout does not come before stderr" ) self.assertEqual(result.stderr, None) + + def test_stdout_flood_results_in_graceful_sigterm(self): + stdout_flood = dedent(""" + while True: + print('abcdefghij') + """).strip() + + result = self.nsjail.python3(stdout_flood) + self.assertEqual(result.returncode, 143) |