diff options
-rw-r--r-- | snekbox/nsjail.py | 10 | ||||
-rw-r--r-- | snekbox/utils/timed.py | 9 | ||||
-rw-r--r-- | tests/test_nsjail.py | 34 |
3 files changed, 40 insertions, 13 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 006ff98..fb4ff1d 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -241,10 +241,12 @@ class NsJail: (fs, self.files_limit, self.files_pattern), timeout=self.files_timeout, ) - log.info(f"Found {len(attachments)} attachments.") - except TimeoutError as err: - log.info(f"Exceeded time limit in parsing attachments: {err}") - return EvalResult(args, returncode, f"AttachmentError: {err}") + log.info(f"Found {len(attachments)} files.") + except TimeoutError as e: + log.warning(f"Exceeded time limit while parsing attachments: {e}") + return EvalResult( + args, None, "TimeoutError: Exceeded time limit while parsing attachments" + ) log_lines = nsj_log.read().decode("utf-8").splitlines() if not log_lines and returncode == 255: diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py index e1602b5..4acfa0d 100644 --- a/snekbox/utils/timed.py +++ b/snekbox/utils/timed.py @@ -1,6 +1,6 @@ """Calling functions with time limits.""" +import multiprocessing from collections.abc import Callable, Iterable, Mapping -from multiprocessing import Pool from typing import Any, TypeVar _T = TypeVar("_T") @@ -23,6 +23,9 @@ def timed( """ if kwds is None: kwds = {} - with Pool(1) as pool: + with multiprocessing.Pool(1) as pool: result = pool.apply_async(func, args, kwds) - return result.get(timeout) + try: + return result.get(timeout) + except multiprocessing.TimeoutError as e: + raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 493e1f1..dca0a8f 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -17,19 +17,17 @@ class NsJailTests(unittest.TestCase): def setUp(self): super().setUp() - self.nsjail = NsJail() - # Set a lower memfs size limit so tests don't exceed time limit - self.nsjail.memfs_instance_size = 2 * 1024 * 1024 # 2MiB - + # Specify lower limits for unit tests to complete within time limits + self.nsjail = NsJail(memfs_instance_size=2 * 1024 * 1024) self.logger = logging.getLogger("snekbox.nsjail") self.logger.setLevel(logging.WARNING) def eval_code(self, code: str): return self.nsjail.python3(["-c", code]) - def eval_file(self, code: str, name: str = "test.py"): + def eval_file(self, code: str, name: str = "test.py", **kwargs): file = FileAttachment(name, code) - return self.nsjail.python3([name], [file]) + return self.nsjail.python3([name], [file], **kwargs) def test_print_returns_0(self): for fn in (self.eval_code, self.eval_file): @@ -185,6 +183,30 @@ class NsJailTests(unittest.TestCase): self.assertIn("Resource temporarily unavailable", result.stdout) self.assertEqual(result.stderr, None) + def test_file_parsing_timeout(self): + code = dedent( + """ + import os + data = "a" * 1024 + size = 32 * 1024 * 1024 + + with open("src", "w") as f: + for _ in range((size // 1024) - 5): + f.write(data) + + for i in range(100): + os.symlink("src", f"output{i}") + """ + ).strip() + + nsjail = NsJail(memfs_instance_size=48 * 1024 * 1024, files_timeout=1) + result = nsjail.python3(["-c", code]) + self.assertEqual(result.returncode, None) + self.assertEqual( + result.stdout, "TimeoutError: Exceeded time limit while parsing attachments" + ) + self.assertEqual(result.stderr, None) + def test_sigsegv_returns_139(self): # In honour of Juan. code = dedent( """ |