diff options
author | 2023-03-13 14:13:37 -0400 | |
---|---|---|
committer | 2023-03-13 14:13:37 -0400 | |
commit | 7ca391715068d5bc7627265f83953b7cb3851b71 (patch) | |
tree | 4032e8c1a46d96f288c11285549a0302007f6759 | |
parent | Add unit test for deeply nested path file parsing (diff) |
Add SIGALRM based time limit
-rw-r--r-- | snekbox/nsjail.py | 25 | ||||
-rw-r--r-- | snekbox/utils/timed.py | 30 |
2 files changed, 42 insertions, 13 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 8bbcf22..77cd321 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -15,7 +15,7 @@ from snekbox.filesystem import Size from snekbox.memfs import MemFS from snekbox.process import EvalResult from snekbox.snekio import FileAttachment -from snekbox.utils.timed import timed +from snekbox.utils.timed import time_limit __all__ = ("NsJail",) @@ -267,16 +267,14 @@ class NsJail: # Parse attachments with time limit try: - attachments = timed( - MemFS.files_list, - (fs, self.files_limit, self.files_pattern), - { - "preload_dict": True, - "exclude_files": files_written, - "timeout": self.files_timeout, - }, - timeout=self.files_timeout, - ) + with time_limit(self.files_timeout): + attachments = fs.files_list( + limit=self.files_limit, + pattern=self.files_pattern, + preload_dict=True, + exclude_files=files_written, + timeout=self.files_timeout, + ) log.info(f"Found {len(attachments)} files.") except RecursionError: log.info("Recursion error while parsing attachments") @@ -290,6 +288,11 @@ class NsJail: return EvalResult( args, None, "TimeoutError: Exceeded time limit while parsing attachments" ) + except Exception as e: + log.error(f"Unexpected {type(e).__name__} while parse attachments: {e}") + return EvalResult( + args, None, "FileParsingError: Unknown error while parsing attachments" + ) log_lines = nsj_log.read().decode("utf-8").splitlines() if not log_lines and returncode == 255: diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py index 1221df0..bac299d 100644 --- a/snekbox/utils/timed.py +++ b/snekbox/utils/timed.py @@ -1,12 +1,14 @@ """Calling functions with time limits.""" import multiprocessing -from collections.abc import Callable, Iterable, Mapping +import signal +from collections.abc import Callable, Generator, Iterable, Mapping +from contextlib import contextmanager from typing import Any, TypeVar _T = TypeVar("_T") _V = TypeVar("_V") -__all__ = ("timed",) +__all__ = ("timed", "time_limit") def timed( @@ -35,3 +37,27 @@ def timed( return result.get(timeout) except multiprocessing.TimeoutError as e: raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e + + +@contextmanager +def time_limit(timeout: int | None = None) -> Generator[None, None, None]: + """ + Decorator to call a function with a time limit. Uses SIGALRM, requires a UNIX system. + + Args: + timeout: Timeout limit in seconds. + + Raises: + TimeoutError: If the function call takes longer than `timeout` seconds. + """ + + def signal_handler(signum, frame): + raise TimeoutError(f"time_limit call timed out after {timeout} seconds.") + + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(timeout) + + try: + yield + finally: + signal.alarm(0) |