diff options
| -rw-r--r-- | snekbox/memfs.py | 7 | ||||
| -rw-r--r-- | snekbox/nsjail.py | 33 | ||||
| -rw-r--r-- | snekbox/snekio.py | 19 | ||||
| -rw-r--r-- | snekbox/utils/timed.py | 28 | 
4 files changed, 55 insertions, 32 deletions
| diff --git a/snekbox/memfs.py b/snekbox/memfs.py index 727bc4c..03251cd 100644 --- a/snekbox/memfs.py +++ b/snekbox/memfs.py @@ -104,18 +104,17 @@ class MemFS:          return folder      def attachments( -        self, max_count: int, max_size: int | None = None +        self, max_count: int, pattern: str = "output*"      ) -> Generator[FileAttachment, None, None]:          """Return a list of attachments in the tempdir."""          count = 0 -        # Look for any file starting with `output` -        for file in self.home.glob("output*"): +        for file in self.home.glob(pattern):              if count > max_count:                  log.info(f"Max attachments {max_count} reached, skipping remaining files")                  break              if file.is_file():                  count += 1 -                yield FileAttachment.from_path(file, max_size) +                yield FileAttachment.from_path(file)      def cleanup(self) -> None:          """Unmounts tmpfs, releases name.""" diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 69b6599..d4255e6 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -14,7 +14,8 @@ from snekbox.memfs import MemFS  __all__ = ("NsJail",)  from snekbox.process import EvalResult -from snekbox.snekio import AttachmentError, FileAttachment +from snekbox.snekio import FileAttachment +from snekbox.utils.timed import timed  log = logging.getLogger(__name__) @@ -24,6 +25,15 @@ LOG_PATTERN = re.compile(  ) +def parse_files( +    fs: MemFS, +    files_limit: int, +    files_pattern: str, +) -> list[FileAttachment]: +    """Parse files in a MemFS.""" +    return sorted(fs.attachments(files_limit, files_pattern), key=lambda file: file.name) + +  class NsJail:      """      Core Snekbox functionality, providing safe execution of Python code. @@ -38,16 +48,18 @@ class NsJail:          max_output_size: int = 1_000_000,          read_chunk_size: int = 10_000,          memfs_instance_size: int = 48 * 1024 * 1024, -        max_attachments: int = 100, -        max_attachment_size: int | None = None, +        files_limit: int = 100, +        files_timeout: float = 15, +        files_pattern: str = "output*",      ):          self.nsjail_path = nsjail_path          self.config_path = config_path          self.max_output_size = max_output_size          self.read_chunk_size = read_chunk_size          self.memfs_instance_size = memfs_instance_size -        self.max_attachments = max_attachments -        self.max_attachment_size = max_attachment_size +        self.files_limit = files_limit +        self.files_timeout = files_timeout +        self.files_pattern = files_pattern          self.config = self._read_config(config_path)          self.cgroup_version = utils.cgroup.init(self.config) @@ -227,13 +239,14 @@ class NsJail:              # Parse attachments              try:                  # Sort attachments by name lexically -                attachments = sorted( -                    fs.attachments(self.max_attachments, self.max_attachment_size), -                    key=lambda a: a.name, +                attachments = timed( +                    parse_files, +                    (fs, self.files_limit, self.files_pattern), +                    timeout=self.files_timeout,                  )                  log.info(f"Found {len(attachments)} attachments.") -            except AttachmentError as err: -                log.info(f"Failed to parse attachments: {err}") +            except TimeoutError as err: +                log.info(f"Exceeded time limit in parsing attachments: {err}")                  return EvalResult(args, returncode, f"AttachmentError: {err}")              log_lines = nsj_log.read().decode("utf-8").splitlines() diff --git a/snekbox/snekio.py b/snekbox/snekio.py index 020ca8c..d3b3dce 100644 --- a/snekbox/snekio.py +++ b/snekbox/snekio.py @@ -7,20 +7,9 @@ from dataclasses import dataclass  from pathlib import Path  from typing import Generic, TypeVar -RequestType = dict[str, str | bool | list[str | dict[str, str]]] -  T = TypeVar("T", str, bytes) -def sizeof_fmt(num: int, suffix: str = "B") -> str: -    """Return a human-readable file size.""" -    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): -        if abs(num) < 1024: -            return f"{num:3.1f}{unit}{suffix}" -        num /= 1024 -    return f"{num:.1f}Yi{suffix}" - -  class AttachmentError(ValueError):      """Raised when an attachment is invalid.""" @@ -64,14 +53,8 @@ class FileAttachment(Generic[T]):          return cls(name, content)      @classmethod -    def from_path(cls, file: Path, max_size: int | None = None) -> FileAttachment[bytes]: +    def from_path(cls, file: Path) -> FileAttachment[bytes]:          """Create an attachment from a file path.""" -        size = file.stat().st_size -        if max_size is not None and size > max_size: -            raise AttachmentError( -                f"File {file.name} too large: {sizeof_fmt(size)} " -                f"exceeds the limit of {sizeof_fmt(max_size)}" -            )          return cls(file.name, file.read_bytes())      @property diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py new file mode 100644 index 0000000..e1602b5 --- /dev/null +++ b/snekbox/utils/timed.py @@ -0,0 +1,28 @@ +"""Calling functions with time limits.""" +from collections.abc import Callable, Iterable, Mapping +from multiprocessing import Pool +from typing import Any, TypeVar + +_T = TypeVar("_T") +_V = TypeVar("_V") + +__all__ = ("timed",) + + +def timed( +    func: Callable[[_T], _V], +    args: Iterable = (), +    kwds: Mapping[str, Any] | None = None, +    timeout: float | None = None, +) -> _V: +    """ +    Call a function with a time limit. + +    Raises: +        TimeoutError: If the function call takes longer than `timeout` seconds. +    """ +    if kwds is None: +        kwds = {} +    with Pool(1) as pool: +        result = pool.apply_async(func, args, kwds) +        return result.get(timeout) | 
