aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--snekbox/memfs.py31
-rw-r--r--snekbox/nsjail.py33
-rw-r--r--snekbox/utils/timed.py41
-rw-r--r--tests/test_nsjail.py23
-rw-r--r--tests/test_timed.py30
5 files changed, 117 insertions, 41 deletions
diff --git a/snekbox/memfs.py b/snekbox/memfs.py
index ddea9a9..991766b 100644
--- a/snekbox/memfs.py
+++ b/snekbox/memfs.py
@@ -1,7 +1,9 @@
"""Memory filesystem for snekbox."""
from __future__ import annotations
+import glob
import logging
+import time
import warnings
import weakref
from collections.abc import Generator
@@ -125,6 +127,7 @@ class MemFS:
limit: int,
pattern: str = "**/*",
exclude_files: dict[Path, float] | None = None,
+ timeout: float | None = None,
) -> Generator[FileAttachment, None, None]:
"""
Yields FileAttachments for files found in the output directory.
@@ -135,12 +138,18 @@ class MemFS:
exclude_files: A dict of Paths and last modified times.
Files will be excluded if their last modified time
is equal to the provided value.
+ timeout: Maximum time in seconds for file parsing.
+ Raises:
+ TimeoutError: If file parsing exceeds timeout.
"""
+ start_time = time.monotonic()
count = 0
- for file in self.output.rglob(pattern):
- # Ignore hidden directories or files
- if any(part.startswith(".") for part in file.parts):
- log.info(f"Skipping hidden path {file!s}")
+ files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
+ for file in (Path(self.output, f) for f in files):
+ if timeout and (time.monotonic() - start_time) > timeout:
+ raise TimeoutError("File parsing timeout exceeded in MemFS.files")
+
+ if not file.is_file():
continue
if exclude_files and (orig_time := exclude_files.get(file)):
@@ -154,10 +163,9 @@ class MemFS:
log.info(f"Max attachments {limit} reached, skipping remaining files")
break
- if file.is_file():
- count += 1
- log.info(f"Found valid file for upload {file.name!r}")
- yield FileAttachment.from_path(file, relative_to=self.output)
+ count += 1
+ log.info(f"Found valid file for upload {file.name!r}")
+ yield FileAttachment.from_path(file, relative_to=self.output)
def files_list(
self,
@@ -165,6 +173,7 @@ class MemFS:
pattern: str,
exclude_files: dict[Path, float] | None = None,
preload_dict: bool = False,
+ timeout: float | None = None,
) -> list[FileAttachment]:
"""
Return a sorted list of file paths within the output directory.
@@ -176,15 +185,21 @@ class MemFS:
Files will be excluded if their last modified time
is equal to the provided value.
preload_dict: Whether to preload as_dict property data.
+ timeout: Maximum time in seconds for file parsing.
Returns:
List of FileAttachments sorted lexically by path name.
+ Raises:
+ TimeoutError: If file parsing exceeds timeout.
"""
+ start_time = time.monotonic()
res = sorted(
self.files(limit=limit, pattern=pattern, exclude_files=exclude_files),
key=lambda f: f.path,
)
if preload_dict:
for file in res:
+ if timeout and (time.monotonic() - start_time) > timeout:
+ raise TimeoutError("File parsing timeout exceeded in MemFS.files_list")
# Loads the cached property as attribute
_ = file.as_dict
return res
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index f014850..f64830a 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -15,7 +15,7 @@ from snekbox.filesystem import Size
from snekbox.memfs import MemFS
from snekbox.process import EvalResult
from snekbox.snekio import FileAttachment
-from snekbox.utils.timed import timed
+from snekbox.utils.timed import time_limit
__all__ = ("NsJail",)
@@ -56,7 +56,7 @@ class NsJail:
memfs_home: str = "home",
memfs_output: str = "home",
files_limit: int | None = 100,
- files_timeout: float | None = 8,
+ files_timeout: int | None = 5,
files_pattern: str = "**/[!_]*",
):
"""
@@ -267,21 +267,32 @@ class NsJail:
# Parse attachments with time limit
try:
- attachments = timed(
- MemFS.files_list,
- (fs, self.files_limit, self.files_pattern),
- {
- "preload_dict": True,
- "exclude_files": files_written,
- },
- timeout=self.files_timeout,
- )
+ with time_limit(self.files_timeout):
+ attachments = fs.files_list(
+ limit=self.files_limit,
+ pattern=self.files_pattern,
+ preload_dict=True,
+ exclude_files=files_written,
+ timeout=self.files_timeout,
+ )
log.info(f"Found {len(attachments)} files.")
+ except RecursionError:
+ log.info("Recursion error while parsing attachments")
+ return EvalResult(
+ args,
+ None,
+ "FileParsingError: Exceeded directory depth limit while parsing attachments",
+ )
except TimeoutError as e:
log.info(f"Exceeded time limit while parsing attachments: {e}")
return EvalResult(
args, None, "TimeoutError: Exceeded time limit while parsing attachments"
)
+ except Exception as e:
+ log.exception(f"Unexpected {type(e).__name__} while parse attachments", exc_info=e)
+ return EvalResult(
+ args, None, "FileParsingError: Unknown error while parsing attachments"
+ )
log_lines = nsj_log.read().decode("utf-8").splitlines()
if not log_lines and returncode == 255:
diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py
index 02388ff..11d126c 100644
--- a/snekbox/utils/timed.py
+++ b/snekbox/utils/timed.py
@@ -1,37 +1,34 @@
"""Calling functions with time limits."""
-import multiprocessing
-from collections.abc import Callable, Iterable, Mapping
-from typing import Any, TypeVar
+import signal
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import TypeVar
_T = TypeVar("_T")
_V = TypeVar("_V")
-__all__ = ("timed",)
+__all__ = ("time_limit",)
-def timed(
-    func: Callable[[_T], _V],
-    args: Iterable = (),
-    kwds: Mapping[str, Any] | None = None,
-    timeout: float | None = None,
-) -> _V:
+@contextmanager
+def time_limit(timeout: int | None = None) -> Generator[None, None, None]:
     """
-    Call a function with a time limit.
+    Context manager that limits the run time of the enclosed block via SIGALRM.
     Args:
-        func: Function to call.
-        args: Arguments for function.
-        kwds: Keyword arguments for function.
         timeout: Timeout limit in seconds.
     Raises:
         TimeoutError: If the function call takes longer than `timeout` seconds.
     """
-    if kwds is None:
-        kwds = {}
-    with multiprocessing.Pool(1) as pool:
-        result = pool.apply_async(func, args, kwds)
-        try:
-            return result.get(timeout)
-        except multiprocessing.TimeoutError as e:
-            raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e
+
+    def signal_handler(_signum, _frame):
+        raise TimeoutError(f"time_limit call timed out after {timeout} seconds.")
+
+    previous_handler = signal.signal(signal.SIGALRM, signal_handler)  # kept for restore
+    signal.alarm(timeout or 0)  # alarm(0) arms nothing, so a None timeout means no limit
+    try:
+        yield
+    finally:
+        signal.alarm(0)  # cancel any pending alarm on every exit path
+        signal.signal(signal.SIGALRM, previous_handler)  # undo the handler swap
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index 456046b..c701d3a 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -227,6 +227,29 @@ class NsJailTests(unittest.TestCase):
)
self.assertEqual(result.stderr, None)
+ def test_file_parsing_depth_limit(self):  # sandboxed code nests 1000 directories, then writes a file
+ code = dedent(
+ """
+ import os
+
+ x = ""
+ for _ in range(1000):
+ x += "a/"
+ os.mkdir(x)
+
+ open(f"{x}test.txt", "w").write("test")
+ """
+ ).strip()
+
+ nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=5)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, None)  # no return code is expected for this failure
+ self.assertEqual(
+ result.stdout,
+ "FileParsingError: Exceeded directory depth limit while parsing attachments",
+ )
+ self.assertEqual(result.stderr, None)  # the error text is expected in stdout only
+
def test_file_write_error(self):
"""Test errors during file write."""
result = self.nsjail.python3(
diff --git a/tests/test_timed.py b/tests/test_timed.py
new file mode 100644
index 0000000..e46bd37
--- /dev/null
+++ b/tests/test_timed.py
@@ -0,0 +1,30 @@
+import math
+import time
+from unittest import TestCase
+
+from snekbox.utils.timed import time_limit
+
+
+class TimedTests(TestCase):  # exercises time_limit as a `with` block
+ def test_sleep(self):
+ """Test that a sleep can be interrupted."""
+ _finished = False
+ start = time.perf_counter()
+ with self.assertRaises(TimeoutError):
+ with time_limit(1):  # 1 s limit around a 2 s sleep
+ time.sleep(2)
+ _finished = True  # must not run: TimeoutError leaves the block first
+ end = time.perf_counter()
+ self.assertLess(end - start, 2)  # shorter than the full sleep, so it was cut short
+ self.assertFalse(_finished)
+
+ def test_iter(self):
+ """Test that a long-running built-in function can be interrupted."""
+ _result = 0
+ start = time.perf_counter()
+ with self.assertRaises(TimeoutError):
+ with time_limit(1):
+ _result = math.factorial(2**30)  # assignment is skipped if the call is interrupted
+ end = time.perf_counter()
+ self.assertEqual(_result, 0)  # still the initial value: factorial never returned
+ self.assertLess(end - start, 2)