diff options
-rw-r--r-- | snekbox/memfs.py | 31 | ||||
-rw-r--r-- | snekbox/nsjail.py | 33 | ||||
-rw-r--r-- | snekbox/utils/timed.py | 41 | ||||
-rw-r--r-- | tests/test_nsjail.py | 23 | ||||
-rw-r--r-- | tests/test_timed.py | 30 |
5 files changed, 117 insertions, 41 deletions
diff --git a/snekbox/memfs.py b/snekbox/memfs.py index ddea9a9..991766b 100644 --- a/snekbox/memfs.py +++ b/snekbox/memfs.py @@ -1,7 +1,9 @@ """Memory filesystem for snekbox.""" from __future__ import annotations +import glob import logging +import time import warnings import weakref from collections.abc import Generator @@ -125,6 +127,7 @@ class MemFS: limit: int, pattern: str = "**/*", exclude_files: dict[Path, float] | None = None, + timeout: float | None = None, ) -> Generator[FileAttachment, None, None]: """ Yields FileAttachments for files found in the output directory. @@ -135,12 +138,18 @@ class MemFS: exclude_files: A dict of Paths and last modified times. Files will be excluded if their last modified time is equal to the provided value. + timeout: Maximum time in seconds for file parsing. + Raises: + TimeoutError: If file parsing exceeds timeout. """ + start_time = time.monotonic() count = 0 - for file in self.output.rglob(pattern): - # Ignore hidden directories or files - if any(part.startswith(".") for part in file.parts): - log.info(f"Skipping hidden path {file!s}") + files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False) + for file in (Path(self.output, f) for f in files): + if timeout and (time.monotonic() - start_time) > timeout: + raise TimeoutError("File parsing timeout exceeded in MemFS.files") + + if not file.is_file(): continue if exclude_files and (orig_time := exclude_files.get(file)): @@ -154,10 +163,9 @@ class MemFS: log.info(f"Max attachments {limit} reached, skipping remaining files") break - if file.is_file(): - count += 1 - log.info(f"Found valid file for upload {file.name!r}") - yield FileAttachment.from_path(file, relative_to=self.output) + count += 1 + log.info(f"Found valid file for upload {file.name!r}") + yield FileAttachment.from_path(file, relative_to=self.output) def files_list( self, @@ -165,6 +173,7 @@ class MemFS: pattern: str, exclude_files: dict[Path, float] | None = None, preload_dict: bool = False, + timeout: float | None = None, ) -> list[FileAttachment]: """ Return a sorted list of file paths within the output directory. @@ -176,15 +185,21 @@ class MemFS: Files will be excluded if their last modified time is equal to the provided value. preload_dict: Whether to preload as_dict property data. + timeout: Maximum time in seconds for file parsing. Returns: List of FileAttachments sorted lexically by path name. + Raises: + TimeoutError: If file parsing exceeds timeout. """ + start_time = time.monotonic() res = sorted( self.files(limit=limit, pattern=pattern, exclude_files=exclude_files), key=lambda f: f.path, ) if preload_dict: for file in res: + if timeout and (time.monotonic() - start_time) > timeout: + raise TimeoutError("File parsing timeout exceeded in MemFS.files_list") # Loads the cached property as attribute _ = file.as_dict return res diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index f014850..f64830a 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -15,7 +15,7 @@ from snekbox.filesystem import Size from snekbox.memfs import MemFS from snekbox.process import EvalResult from snekbox.snekio import FileAttachment -from snekbox.utils.timed import timed +from snekbox.utils.timed import time_limit __all__ = ("NsJail",) @@ -56,7 +56,7 @@ class NsJail: memfs_home: str = "home", memfs_output: str = "home", files_limit: int | None = 100, - files_timeout: float | None = 8, + files_timeout: int | None = 5, files_pattern: str = "**/[!_]*", ): """ @@ -267,21 +267,32 @@ class NsJail: # Parse attachments with time limit try: - attachments = timed( - MemFS.files_list, - (fs, self.files_limit, self.files_pattern), - { - "preload_dict": True, - "exclude_files": files_written, - }, - timeout=self.files_timeout, - ) + with time_limit(self.files_timeout): + attachments = fs.files_list( + limit=self.files_limit, + pattern=self.files_pattern, + preload_dict=True, + exclude_files=files_written, + timeout=self.files_timeout, + ) log.info(f"Found {len(attachments)} files.") + except RecursionError: + log.info("Recursion error while parsing attachments") + return EvalResult( + args, + None, + "FileParsingError: Exceeded directory depth limit while parsing attachments", + ) except TimeoutError as e: log.info(f"Exceeded time limit while parsing attachments: {e}") return EvalResult( args, None, "TimeoutError: Exceeded time limit while parsing attachments" ) + except Exception as e: + log.exception(f"Unexpected {type(e).__name__} while parse attachments", exc_info=e) + return EvalResult( + args, None, "FileParsingError: Unknown error while parsing attachments" + ) log_lines = nsj_log.read().decode("utf-8").splitlines() if not log_lines and returncode == 255: diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py index 02388ff..11d126c 100644 --- a/snekbox/utils/timed.py +++ b/snekbox/utils/timed.py @@ -1,37 +1,34 @@ """Calling functions with time limits.""" -import multiprocessing -from collections.abc import Callable, Iterable, Mapping -from typing import Any, TypeVar +import signal +from collections.abc import Generator +from contextlib import contextmanager +from typing import TypeVar _T = TypeVar("_T") _V = TypeVar("_V") -__all__ = ("timed",) +__all__ = ("time_limit",) -def timed( - func: Callable[[_T], _V], - args: Iterable = (), - kwds: Mapping[str, Any] | None = None, - timeout: float | None = None, -) -> _V: +@contextmanager +def time_limit(timeout: int | None = None) -> Generator[None, None, None]: """ - Call a function with a time limit. + Decorator to call a function with a time limit. Args: - func: Function to call. - args: Arguments for function. - kwds: Keyword arguments for function. timeout: Timeout limit in seconds. Raises: TimeoutError: If the function call takes longer than `timeout` seconds. """ - if kwds is None: - kwds = {} - with multiprocessing.Pool(1) as pool: - result = pool.apply_async(func, args, kwds) - try: - return result.get(timeout) - except multiprocessing.TimeoutError as e: - raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e + + def signal_handler(_signum, _frame): + raise TimeoutError(f"time_limit call timed out after {timeout} seconds.") + + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(timeout) + + try: + yield + finally: + signal.alarm(0) diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 456046b..c701d3a 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -227,6 +227,29 @@ class NsJailTests(unittest.TestCase): ) self.assertEqual(result.stderr, None) + def test_file_parsing_depth_limit(self): + code = dedent( + """ + import os + + x = "" + for _ in range(1000): + x += "a/" + os.mkdir(x) + + open(f"{x}test.txt", "w").write("test") + """ + ).strip() + + nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=5) + result = nsjail.python3(["-c", code]) + self.assertEqual(result.returncode, None) + self.assertEqual( + result.stdout, + "FileParsingError: Exceeded directory depth limit while parsing attachments", + ) + self.assertEqual(result.stderr, None) + def test_file_write_error(self): """Test errors during file write.""" result = self.nsjail.python3( diff --git a/tests/test_timed.py b/tests/test_timed.py new file mode 100644 index 0000000..e46bd37 --- /dev/null +++ b/tests/test_timed.py @@ -0,0 +1,30 @@ +import math +import time +from unittest import TestCase + +from snekbox.utils.timed import time_limit + + +class TimedTests(TestCase): + def test_sleep(self): + """Test that a sleep can be interrupted.""" + _finished = False + start = time.perf_counter() + with self.assertRaises(TimeoutError): + with time_limit(1): + time.sleep(2) + _finished = True + end = time.perf_counter() + self.assertLess(end - start, 2) + self.assertFalse(_finished) + + def test_iter(self): + """Test that a long-running built-in function can be interrupted.""" + _result = 0 + start = time.perf_counter() + with self.assertRaises(TimeoutError): + with time_limit(1): + _result = math.factorial(2**30) + end = time.perf_counter() + self.assertEqual(_result, 0) + self.assertLess(end - start, 2) |