diff options
| -rw-r--r-- | README.md | 21 | ||||
| -rw-r--r-- | config/snekbox.cfg | 8 | ||||
| -rw-r--r-- | docker-compose.yml | 2 | ||||
| -rw-r--r-- | snekbox/__init__.py | 2 | ||||
| -rw-r--r-- | snekbox/__main__.py | 2 | ||||
| -rw-r--r-- | snekbox/api/resources/eval.py | 82 | ||||
| -rw-r--r-- | snekbox/filesystem.py | 88 | ||||
| -rw-r--r-- | snekbox/memfs.py | 185 | ||||
| -rw-r--r-- | snekbox/nsjail.py | 137 | ||||
| -rw-r--r-- | snekbox/process.py | 32 | ||||
| -rw-r--r-- | snekbox/snekio.py | 99 | ||||
| -rw-r--r-- | snekbox/utils/__init__.py | 4 | ||||
| -rw-r--r-- | snekbox/utils/timed.py | 37 | ||||
| -rw-r--r-- | tests/api/__init__.py | 4 | ||||
| -rw-r--r-- | tests/api/test_eval.py | 106 | ||||
| -rw-r--r-- | tests/test_filesystem.py | 143 | ||||
| -rw-r--r-- | tests/test_integration.py | 81 | ||||
| -rw-r--r-- | tests/test_main.py | 2 | ||||
| -rw-r--r-- | tests/test_memfs.py | 65 | ||||
| -rw-r--r-- | tests/test_nsjail.py | 170 | ||||
| -rw-r--r-- | tests/test_snekio.py | 58 | 
21 files changed, 1231 insertions, 97 deletions
| @@ -7,6 +7,9 @@  Python sandbox runners for executing code in isolation aka snekbox. +Supports an in-memory [virtual read/write file system](#memory-file-system) within the sandbox, +allowing text or binary files to be sent and returned. +  A client sends Python code to a snekbox, the snekbox executes the code, and finally the results of the execution are returned to the client.  ```mermaid @@ -60,10 +63,26 @@ The main features of the default configuration are:  * Memory limit  * Process count limit  * No networking -* Restricted, read-only filesystem +* Restricted, read-only system filesystem +* Memory-based read-write filesystem mounted as working directory `/home`  NsJail is configured through [`snekbox.cfg`]. It contains the exact values for the items listed above. The configuration format is defined by a [protobuf file][7] which can be referred to for documentation. The command-line options of NsJail can also serve as documentation since they closely follow the config file format. +### Memory File System + +On each execution, the host will mount an instance-specific `tmpfs` drive; this is used as a limited read-write folder for the sandboxed code. There is no access to other files or directories on the host container beyond the other read-only mounted system folders. Instance file systems are isolated; it is not possible for sandboxed code to access another instance's writeable directory. + +The following options for the memory file system are configurable as options in [gunicorn.conf.py](config/gunicorn.conf.py): + +* `memfs_instance_size` Size in bytes for the capacity of each instance file system. +* `memfs_home` Path to the home directory within the instance file system. +* `memfs_output` Path to the output directory within the instance file system. +* `files_limit` Maximum number of valid output files to parse. +* `files_timeout` Maximum time in seconds for output file parsing and encoding. +* `files_pattern` Glob pattern to match files within `output`. 
+ +The sandboxed code execution will start with a writeable working directory of `home`. By default, the output folder is also `home`. New files, and uploaded files with a newer last modified time, will be uploaded on completion. +  ### Gunicorn  [Gunicorn settings] can be found in [`gunicorn.conf.py`]. In the default configuration, the worker count, the bind address, and the WSGI app URI are likely the only things of any interest. Since it uses the default synchronous workers, the [worker count] effectively determines how many concurrent code evaluations can be performed. diff --git a/config/snekbox.cfg b/config/snekbox.cfg index 87c216e..5dd63da 100644 --- a/config/snekbox.cfg +++ b/config/snekbox.cfg @@ -3,7 +3,7 @@ description: "Execute Python"  mode: ONCE  hostname: "snekbox" -cwd: "/snekbox" +cwd: "/home"  time_limit: 6 @@ -16,10 +16,12 @@ envar: "VECLIB_MAXIMUM_THREADS=5"  envar: "NUMEXPR_NUM_THREADS=5"  envar: "PYTHONPATH=/snekbox/user_base/lib/python3.11/site-packages"  envar: "PYTHONIOENCODING=utf-8:strict" +envar: "HOME=home"  keep_caps: false  rlimit_as: 700 +rlimit_fsize_type: INF  clone_newnet: true  clone_newuser: true @@ -108,12 +110,12 @@ cgroup_mem_max: 52428800  cgroup_mem_swap_max: 0  cgroup_mem_mount: "/sys/fs/cgroup/memory" -cgroup_pids_max: 5 +cgroup_pids_max: 6  cgroup_pids_mount: "/sys/fs/cgroup/pids"  iface_no_lo: true  exec_bin {      path: "/usr/local/bin/python" -    arg: "-Squ" +    arg: "-BSqu"  } diff --git a/docker-compose.yml b/docker-compose.yml index 9d3ae71..0613abc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,7 @@ services:      image: ghcr.io/python-discord/snekbox${IMAGE_SUFFIX:--venv:dev}      pull_policy: never      ports: -     - 8060:8060 +     - "8060:8060"      init: true      ipc: none      tty: true diff --git a/snekbox/__init__.py b/snekbox/__init__.py index bed3692..b45960b 100644 --- a/snekbox/__init__.py +++ b/snekbox/__init__.py @@ -12,7 +12,7 @@ from snekbox.api import SnekAPI  # noqa: 
E402  from snekbox.nsjail import NsJail  # noqa: E402  from snekbox.utils.logging import init_logger, init_sentry  # noqa: E402 -__all__ = ("NsJail", "SnekAPI") +__all__ = ("NsJail", "SnekAPI", "DEBUG")  init_sentry(__version__)  init_logger(DEBUG) diff --git a/snekbox/__main__.py b/snekbox/__main__.py index 2382c4c..239f5d5 100644 --- a/snekbox/__main__.py +++ b/snekbox/__main__.py @@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace:  def main() -> None:      """Evaluate Python code through NsJail."""      args = parse_args() -    result = NsJail().python3(args.code, nsjail_args=args.nsjail_args, py_args=args.py_args) +    result = NsJail().python3(py_args=[*args.py_args, args.code], nsjail_args=args.nsjail_args)      print(result.stdout)      if result.returncode != 0: diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py index 1df6c1b..9a53577 100644 --- a/snekbox/api/resources/eval.py +++ b/snekbox/api/resources/eval.py @@ -1,3 +1,5 @@ +from __future__ import annotations +  import logging  import falcon @@ -7,6 +9,8 @@ from snekbox.nsjail import NsJail  __all__ = ("EvalResource",) +from snekbox.snekio import FileAttachment, ParsingError +  log = logging.getLogger(__name__) @@ -25,8 +29,26 @@ class EvalResource:          "properties": {              "input": {"type": "string"},              "args": {"type": "array", "items": {"type": "string"}}, +            "files": { +                "type": "array", +                "items": { +                    "type": "object", +                    "properties": { +                        "path": { +                            "type": "string", +                            # Disallow starting with / or containing \0 anywhere +                            "pattern": r"^(?!/)(?!.*\\0).*$", +                        }, +                        "content": {"type": "string"}, +                    }, +                    "required": ["path"], +                }, +            },          }, -        
"required": ["input"], +        "anyOf": [ +            {"required": ["input"]}, +            {"required": ["args"]}, +        ],      }      def __init__(self, nsjail: NsJail): @@ -38,8 +60,11 @@ class EvalResource:          Evaluate Python code and return stdout, stderr, and the return code.          A list of arguments for the Python subprocess can be specified as `args`. -        Otherwise, the default argument "-c" is used to execute the input code. -        The input code is always passed as the last argument to Python. + +        If `input` is specified, it will be appended as the last argument to `args`, +        and `args` will have a default argument of `"-c"`. + +        Either `input` or `args` must be specified.          The return codes mostly resemble those of a Unix shell. Some noteworthy cases: @@ -53,15 +78,35 @@ class EvalResource:          Request body:          >>> { -        ...     "input": "[i for i in range(1000)]", -        ...     "args": ["-m", "timeit"] # This is optional +        ...    "input": "print('Hello')" +        ... } + +        >>> { +        ...    "args": ["-c", "print('Hello')"] +        ... } + +        >>> { +        ...    "args": ["main.py"], +        ...    "files": [ +        ...        { +        ...            "path": "main.py", +        ...            "content": "SGVsbG8...="  # Base64 +        ...        } +        ...    ]          ... }          Response format:          >>> { -        ...     "stdout": "10000 loops, best of 5: 23.8 usec per loop\n", -        ...     "returncode": 0 +        ...     "stdout": "10000 loops, best of 5: 23.8 usec per loop", +        ...     "returncode": 0, +        ...     "files": [ +        ...         { +        ...             "path": "output.png", +        ...             "size": 57344, +        ...             "content": "eJzzSM3...="  # Base64 +        ...         } +        ...     ]          ... 
}          Status codes: @@ -69,17 +114,28 @@ class EvalResource:          - 200              Successful evaluation; not indicative that the input code itself works          - 400 -           Input's JSON schema is invalid +           Input JSON schema is invalid          - 415              Unsupported content type; only application/JSON is supported          """ -        code = req.media["input"] -        args = req.media.get("args", ("-c",)) - +        body: dict[str, str | list[str] | list[dict[str, str]]] = req.media +        # If `input` is supplied, default `args` to `-c` +        if "input" in body: +            body.setdefault("args", ["-c"]) +            body["args"].append(body["input"])          try: -            result = self.nsjail.python3(code, py_args=args) +            result = self.nsjail.python3( +                py_args=body["args"], +                files=[FileAttachment.from_dict(file) for file in body.get("files", [])], +            ) +        except ParsingError as e: +            raise falcon.HTTPBadRequest(title="Request file is invalid", description=str(e))          except Exception:              log.exception("An exception occurred while trying to process the request")              raise falcon.HTTPInternalServerError -        resp.media = {"stdout": result.stdout, "returncode": result.returncode} +        resp.media = { +            "stdout": result.stdout, +            "returncode": result.returncode, +            "files": [f.as_dict for f in result.files], +        } diff --git a/snekbox/filesystem.py b/snekbox/filesystem.py new file mode 100644 index 0000000..312707c --- /dev/null +++ b/snekbox/filesystem.py @@ -0,0 +1,88 @@ +"""Mounts and unmounts filesystems.""" +from __future__ import annotations + +import ctypes +import os +from ctypes.util import find_library +from enum import IntEnum +from pathlib import Path + +__all__ = ("mount", "unmount", "Size", "UnmountFlags") + +libc = ctypes.CDLL(find_library("c"), use_errno=True) 
+libc.mount.argtypes = ( +    ctypes.c_char_p, +    ctypes.c_char_p, +    ctypes.c_char_p, +    ctypes.c_ulong, +    ctypes.c_char_p, +) +libc.umount2.argtypes = (ctypes.c_char_p, ctypes.c_int) + + +class Size(IntEnum): +    """Size multipliers for bytes.""" + +    KiB = 1024 +    MiB = 1024**2 +    GiB = 1024**3 +    TiB = 1024**4 + + +class UnmountFlags(IntEnum): +    """Flags for umount2.""" + +    MNT_FORCE = 1 +    MNT_DETACH = 2 +    MNT_EXPIRE = 4 +    UMOUNT_NOFOLLOW = 8 + + +def mount(source: Path | str, target: Path | str, fs: str, **options: str | int) -> None: +    """ +    Mount a filesystem. + +    https://man7.org/linux/man-pages/man8/mount.8.html + +    Args: +        source: Source directory or device. +        target: Target directory. +        fs: Filesystem type. +        **options: Mount options. + +    Raises: +        OSError: On any mount error. +    """ +    if Path(target).is_mount(): +        raise OSError(f"{target} is already a mount point") + +    kwargs = " ".join(f"{key}={value}" for key, value in options.items()) + +    result: int = libc.mount( +        str(source).encode(), str(target).encode(), fs.encode(), 0, kwargs.encode() +    ) +    if result < 0: +        errno = ctypes.get_errno() +        raise OSError(errno, f"Error mounting {target}: {os.strerror(errno)}") + + +def unmount(target: Path | str, flags: UnmountFlags | int = UnmountFlags.MNT_DETACH) -> None: +    """ +    Unmount a filesystem. + +    https://man7.org/linux/man-pages/man2/umount.2.html + +    Args: +        target: Target directory. +        flags: Unmount flags. + +    Raises: +        OSError: On any unmount error. 
+    """ +    if not Path(target).is_mount(): +        raise OSError(f"{target} is not a mount point") + +    result: int = libc.umount2(str(target).encode(), int(flags)) +    if result < 0: +        errno = ctypes.get_errno() +        raise OSError(errno, f"Error unmounting {target}: {os.strerror(errno)}") diff --git a/snekbox/memfs.py b/snekbox/memfs.py new file mode 100644 index 0000000..f32fed1 --- /dev/null +++ b/snekbox/memfs.py @@ -0,0 +1,185 @@ +"""Memory filesystem for snekbox.""" +from __future__ import annotations + +import logging +import warnings +import weakref +from collections.abc import Generator +from contextlib import suppress +from pathlib import Path +from types import TracebackType +from typing import Type +from uuid import uuid4 + +from snekbox.filesystem import mount, unmount +from snekbox.snekio import FileAttachment + +log = logging.getLogger(__name__) + +__all__ = ("MemFS",) + + +class MemFS: +    """An in-memory temporary file system.""" + +    def __init__( +        self, +        instance_size: int, +        root_dir: str | Path = "/memfs", +        home: str = "home", +        output: str = "home", +    ) -> None: +        """ +        Initialize an in-memory temporary file system. + +        Examples: +            >>> with MemFS(1024) as memfs: +            ...     (memfs.home / "test.txt").write_text("Hello") + +        Args: +            instance_size: Size limit of each tmpfs instance in bytes. +            root_dir: Root directory to mount instances in. +            home: Name of the home directory. +            output: Name of the output directory within home. If empty, uses home. 
+        """ +        self.instance_size = instance_size +        self.root_dir = Path(root_dir) +        self.root_dir.mkdir(exist_ok=True, parents=True) +        self._home_name = home +        self._output_name = output + +        for _ in range(10): +            name = str(uuid4()) +            try: +                self.path = self.root_dir / name +                self.path.mkdir() +                mount("", self.path, "tmpfs", size=self.instance_size) +                break +            except OSError: +                continue +        else: +            raise RuntimeError("Failed to generate a unique MemFS name in 10 attempts") + +        self.mkdir(self.home) +        self.mkdir(self.output) + +        self._finalizer = weakref.finalize( +            self, +            self._cleanup, +            self.path, +            warn_message=f"Implicitly cleaning up {self!r}", +        ) + +    @classmethod +    def _cleanup(cls, path: Path, warn_message: str): +        """Implicit cleanup of the MemFS.""" +        with suppress(OSError): +            unmount(path) +            path.rmdir() +        warnings.warn(warn_message, ResourceWarning) + +    def cleanup(self) -> None: +        """Unmount the tempfs and remove the directory.""" +        if self._finalizer.detach() or self.path.exists(): +            unmount(self.path) +            self.path.rmdir() + +    @property +    def name(self) -> str: +        """Name of the temp dir.""" +        return self.path.name + +    @property +    def home(self) -> Path: +        """Path to home directory.""" +        return self.path / self._home_name + +    @property +    def output(self) -> Path: +        """Path to output directory.""" +        return self.path / self._output_name + +    def __enter__(self) -> MemFS: +        return self + +    def __exit__( +        self, +        exc_type: Type[BaseException] | None, +        exc_val: BaseException | None, +        exc_tb: TracebackType | None, +    ) -> None: +       
 self.cleanup() + +    def __repr__(self) -> str: +        return f"<{self.__class__.__name__} {self.path}>" + +    def mkdir(self, path: Path | str, chmod: int = 0o777) -> Path: +        """Create a directory in the tempdir.""" +        folder = Path(self.path, path) +        folder.mkdir(parents=True, exist_ok=True) +        folder.chmod(chmod) +        return folder + +    def files( +        self, +        limit: int, +        pattern: str = "**/*", +        exclude_files: dict[Path, float] | None = None, +    ) -> Generator[FileAttachment, None, None]: +        """ +        Yields FileAttachments for files found in the output directory. + +        Args: +            limit: The maximum number of files to parse. +            pattern: The glob pattern to match files against. +            exclude_files: A dict of Paths and last modified times. +                Files will be excluded if their last modified time +                is equal to the provided value. +        """ +        count = 0 +        for file in self.output.rglob(pattern): +            if exclude_files and (orig_time := exclude_files.get(file)): +                new_time = file.stat().st_mtime +                log.info(f"Checking {file.name} ({orig_time=}, {new_time=})") +                if file.stat().st_mtime == orig_time: +                    log.info(f"Skipping {file.name!r} as it has not been modified") +                    continue + +            if count > limit: +                log.info(f"Max attachments {limit} reached, skipping remaining files") +                break + +            if file.is_file(): +                count += 1 +                log.info(f"Found valid file for upload {file.name!r}") +                yield FileAttachment.from_path(file, relative_to=self.output) + +    def files_list( +        self, +        limit: int, +        pattern: str, +        exclude_files: dict[Path, float] | None = None, +        preload_dict: bool = False, +    ) -> list[FileAttachment]: +       
 """ +        Return a sorted list of file paths within the output directory. + +        Args: +            limit: The maximum number of files to parse. +            pattern: The glob pattern to match files against. +            exclude_files: A dict of Paths and last modified times. +                Files will be excluded if their last modified time +                is equal to the provided value. +            preload_dict: Whether to preload as_dict property data. +        Returns: +            List of FileAttachments sorted lexically by path name. +        """ +        res = sorted( +            self.files(limit=limit, pattern=pattern, exclude_files=exclude_files), +            key=lambda f: f.path, +        ) +        if preload_dict: +            for file in res: +                # Loads the cached property as attribute +                _ = file.as_dict +        return res diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 63afdef..f014850 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -2,26 +2,43 @@ import logging  import re  import subprocess  import sys -import textwrap -from subprocess import CompletedProcess +from collections.abc import Generator +from pathlib import Path  from tempfile import NamedTemporaryFile -from typing import Iterable +from typing import Iterable, TypeVar  from google.protobuf import text_format  from snekbox import DEBUG, utils  from snekbox.config_pb2 import NsJailConfig +from snekbox.filesystem import Size +from snekbox.memfs import MemFS +from snekbox.process import EvalResult +from snekbox.snekio import FileAttachment +from snekbox.utils.timed import timed  __all__ = ("NsJail",)  log = logging.getLogger(__name__) +_T = TypeVar("_T") +  # [level][timestamp][PID]? function_signature:line_no? 
message  LOG_PATTERN = re.compile(      r"\[(?P<level>(I)|[DWEF])\]\[.+?\](?(2)|(?P<func>\[\d+\] .+?:\d+ )) ?(?P<msg>.+)"  ) +def iter_lstrip(iterable: Iterable[_T]) -> Generator[_T, None, None]: +    """Remove leading falsy objects from an iterable.""" +    it = iter(iterable) +    for item in it: +        if item: +            yield item +            break +    yield from it + +  class NsJail:      """      Core Snekbox functionality, providing safe execution of Python code. @@ -35,12 +52,41 @@ class NsJail:          config_path: str = "./config/snekbox.cfg",          max_output_size: int = 1_000_000,          read_chunk_size: int = 10_000, +        memfs_instance_size: int = 48 * Size.MiB, +        memfs_home: str = "home", +        memfs_output: str = "home", +        files_limit: int | None = 100, +        files_timeout: float | None = 8, +        files_pattern: str = "**/[!_]*",      ): +        """ +        Initialize NsJail. + +        Args: +            nsjail_path: Path to the NsJail binary. +            config_path: Path to the NsJail configuration file. +            max_output_size: Maximum size of the output in bytes. +            read_chunk_size: Size of the read buffer in bytes. +            memfs_instance_size: Size of the tmpfs instance in bytes. +            memfs_home: Name of the mounted home directory. +            memfs_output: Name of the output directory within home, +                can be empty to use home as output. +            files_limit: Maximum number of output files to parse. +            files_timeout: Maximum time in seconds to wait for output files to be read. +            files_pattern: Pattern to match files to attach within the output directory. 
+        """          self.nsjail_path = nsjail_path          self.config_path = config_path          self.max_output_size = max_output_size          self.read_chunk_size = read_chunk_size +        self.memfs_instance_size = memfs_instance_size +        self.memfs_home = memfs_home +        self.memfs_output = memfs_output +        self.files_limit = files_limit +        self.files_timeout = files_timeout +        self.files_pattern = files_pattern +          self.config = self._read_config(config_path)          self.cgroup_version = utils.cgroup.init(self.config)          self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version) @@ -129,16 +175,18 @@ class NsJail:          return "".join(output)      def python3( -        self, code: str, *, nsjail_args: Iterable[str] = (), py_args: Iterable[str] = ("-c",) -    ) -> CompletedProcess: +        self, +        py_args: Iterable[str], +        files: Iterable[FileAttachment] = (), +        nsjail_args: Iterable[str] = (), +    ) -> EvalResult:          """          Execute Python 3 code in an isolated environment and return the completed process. -        The `nsjail_args` passed will be used to override the values in the NsJail config. -        These arguments are only options for NsJail; they do not affect Python's arguments. - -        `py_args` are arguments to pass to the Python subprocess before the code, -        which is the last argument. By default, it's "-c", which executes the code given. +        Args: +            py_args: Arguments to pass to Python. +            files: FileAttachments to write to the sandbox prior to running Python. +            nsjail_args: Overrides for the NsJail configuration.          
"""          if self.cgroup_version == 2:              nsjail_args = ("--use_cgroupv2", *nsjail_args) @@ -152,8 +200,19 @@ class NsJail:                  *nsjail_args,              ) -        with NamedTemporaryFile() as nsj_log: -            args = ( +        with NamedTemporaryFile() as nsj_log, MemFS( +            instance_size=self.memfs_instance_size, +            home=self.memfs_home, +            output=self.memfs_output, +        ) as fs: +            nsjail_args = ( +                # Mount `home` with Read/Write access +                "--bindmount", +                f"{fs.home}:home", +                *nsjail_args, +            ) + +            args = [                  self.nsjail_path,                  "--config",                  self.config_path, @@ -163,13 +222,30 @@ class NsJail:                  "--",                  self.config.exec_bin.path,                  *self.config.exec_bin.arg, -                *py_args, -                code, -            ) +                # Filter out empty strings at start of py_args +                # (causes issues with python cli) +                *iter_lstrip(py_args), +            ] + +            # Write provided files if any +            files_written: dict[Path, float] = {} +            for file in files: +                try: +                    f_path = file.save_to(fs.home) +                    # Allow file to be writable +                    f_path.chmod(0o777) +                    # Save the written at time to later check if it was modified +                    files_written[f_path] = f_path.stat().st_mtime +                    log.info(f"Created file at {(fs.home / file.path)!r}.") +                except OSError as e: +                    log.info(f"Failed to create file at {(fs.home / file.path)!r}.", exc_info=e) +                    return EvalResult( +                        args, None, f"{e.__class__.__name__}: Failed to create file '{file.path}'." 
+                    )              msg = "Executing code..."              if DEBUG: -                msg = f"{msg[:-3]}:\n{textwrap.indent(code, '    ')}\nWith the arguments {args}." +                msg = f"{msg[:-3]} with the arguments {args}."              log.info(msg)              try: @@ -177,23 +253,36 @@ class NsJail:                      args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True                  )              except ValueError: -                return CompletedProcess(args, None, "ValueError: embedded null byte", None) +                return EvalResult(args, None, "ValueError: embedded null byte")              try:                  output = self._consume_stdout(nsjail)              except UnicodeDecodeError: -                return CompletedProcess( -                    args, -                    None, -                    "UnicodeDecodeError: invalid Unicode in output pipe", -                    None, -                ) +                return EvalResult(args, None, "UnicodeDecodeError: invalid Unicode in output pipe")              # When you send signal `N` to a subprocess to terminate it using Popen, it              # will return `-N` as its exit code. As we normally get `N + 128` back, we              # convert negative exit codes to the `N + 128` form.              
returncode = -nsjail.returncode + 128 if nsjail.returncode < 0 else nsjail.returncode +            # Parse attachments with time limit +            try: +                attachments = timed( +                    MemFS.files_list, +                    (fs, self.files_limit, self.files_pattern), +                    { +                        "preload_dict": True, +                        "exclude_files": files_written, +                    }, +                    timeout=self.files_timeout, +                ) +                log.info(f"Found {len(attachments)} files.") +            except TimeoutError as e: +                log.info(f"Exceeded time limit while parsing attachments: {e}") +                return EvalResult( +                    args, None, "TimeoutError: Exceeded time limit while parsing attachments" +                ) +              log_lines = nsj_log.read().decode("utf-8").splitlines()              if not log_lines and returncode == 255:                  # NsJail probably failed to parse arguments so log output will still be in stdout @@ -203,4 +292,4 @@ class NsJail:          log.info(f"nsjail return code: {returncode}") -        return CompletedProcess(args, returncode, output, None) +        return EvalResult(args, returncode, output, files=attachments) diff --git a/snekbox/process.py b/snekbox/process.py new file mode 100644 index 0000000..552b91a --- /dev/null +++ b/snekbox/process.py @@ -0,0 +1,32 @@ +"""Utilities for process management.""" +from collections.abc import Sequence +from os import PathLike +from subprocess import CompletedProcess +from typing import TypeVar + +from snekbox.snekio import FileAttachment + +_T = TypeVar("_T") +ArgType = ( +    str +    | bytes +    | PathLike[str] +    | PathLike[bytes] +    | Sequence[str | bytes | PathLike[str] | PathLike[bytes]] +) + + +class EvalResult(CompletedProcess[_T]): +    """An evaluation job that has finished running.""" + +    def __init__( +        self, +        args: ArgType, +     
   returncode: int | None, +        stdout: _T | None = None, +        stderr: _T | None = None, +        files: list[FileAttachment] | None = None, +    ) -> None: +        """Create an evaluation result.""" +        super().__init__(args, returncode, stdout, stderr) +        self.files: list[FileAttachment] = files or [] diff --git a/snekbox/snekio.py b/snekbox/snekio.py new file mode 100644 index 0000000..821f057 --- /dev/null +++ b/snekbox/snekio.py @@ -0,0 +1,99 @@ +"""I/O Operations for sending / receiving files from the sandbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path + + +def safe_path(path: str) -> str: +    """ +    Return `path` if there are no security issues. + +    Raises: +        IllegalPathError: Raised on any path rule violation. +    """ +    # Disallow absolute paths +    if Path(path).is_absolute(): +        raise IllegalPathError(f"File path '{path}' must be relative") + +    # Disallow traversal beyond root +    try: +        test_root = Path("/home") +        Path(test_root).joinpath(path).resolve().relative_to(test_root.resolve()) +    except ValueError: +        raise IllegalPathError(f"File path '{path}' may not traverse beyond root") + +    return path + + +class ParsingError(ValueError): +    """Raised when an incoming content cannot be parsed.""" + + +class IllegalPathError(ParsingError): +    """Raised when a request file has an illegal path.""" + + +@dataclass(frozen=True) +class FileAttachment: +    """A file attachment.""" + +    path: str +    content: bytes + +    def __repr__(self) -> str: +        path = f"{self.path[:30]}..." if len(self.path) > 30 else self.path +        content = f"{self.content[:15]}..." 
if len(self.content) > 15 else self.content +        return f"{self.__class__.__name__}(path={path!r}, content={content!r})" + +    @classmethod +    def from_dict(cls, data: dict[str, str]) -> FileAttachment: +        """ +        Convert a dict to an attachment. + +        Raises: +            ParsingError: Raised when the dict has invalid base64 `content`. +        """ +        path = safe_path(data["path"]) +        try: +            content = b64decode(data.get("content", "")) +        except (TypeError, ValueError) as e: +            raise ParsingError(f"Invalid base64 encoding for file '{path}'") from e +        return cls(path, content) + +    @classmethod +    def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment: +        """ +        Create an attachment from a file path. + +        Args: +            file: The file to attach. +            relative_to: The root for the path name. +        """ +        path = file.relative_to(relative_to) if relative_to else file +        return cls(str(path), file.read_bytes()) + +    @property +    def size(self) -> int: +        """Size of the attachment.""" +        return len(self.content) + +    def save_to(self, directory: Path | str) -> Path: +        """Write the attachment to a file in `directory`. 
Return a Path of the file.""" +        file = Path(directory, self.path) +        # Create directories if they don't exist +        file.parent.mkdir(parents=True, exist_ok=True) +        file.write_bytes(self.content) +        return file + +    @cached_property +    def as_dict(self) -> dict[str, str | int]: +        """Convert the attachment to a dict.""" +        content = b64encode(self.content).decode("ascii") +        return { +            "path": self.path, +            "size": self.size, +            "content": content, +        } diff --git a/snekbox/utils/__init__.py b/snekbox/utils/__init__.py index 6d6bc32..010fa65 100644 --- a/snekbox/utils/__init__.py +++ b/snekbox/utils/__init__.py @@ -1,3 +1,3 @@ -from . import cgroup, logging, swap +from . import cgroup, logging, swap, timed -__all__ = ("cgroup", "logging", "swap") +__all__ = ("cgroup", "logging", "swap", "timed") diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py new file mode 100644 index 0000000..02388ff --- /dev/null +++ b/snekbox/utils/timed.py @@ -0,0 +1,37 @@ +"""Calling functions with time limits.""" +import multiprocessing +from collections.abc import Callable, Iterable, Mapping +from typing import Any, TypeVar + +_T = TypeVar("_T") +_V = TypeVar("_V") + +__all__ = ("timed",) + + +def timed( +    func: Callable[[_T], _V], +    args: Iterable = (), +    kwds: Mapping[str, Any] | None = None, +    timeout: float | None = None, +) -> _V: +    """ +    Call a function with a time limit. + +    Args: +        func: Function to call. +        args: Arguments for function. +        kwds: Keyword arguments for function. +        timeout: Timeout limit in seconds. + +    Raises: +        TimeoutError: If the function call takes longer than `timeout` seconds. 
+    """ +    if kwds is None: +        kwds = {} +    with multiprocessing.Pool(1) as pool: +        result = pool.apply_async(func, args, kwds) +        try: +            return result.get(timeout) +        except multiprocessing.TimeoutError as e: +            raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 0e6e422..5f20faf 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -1,10 +1,10 @@  import logging -from subprocess import CompletedProcess  from unittest import mock  from falcon import testing  from snekbox.api import SnekAPI +from snekbox.process import EvalResult  class SnekAPITestCase(testing.TestCase): @@ -13,7 +13,7 @@ class SnekAPITestCase(testing.TestCase):          self.patcher = mock.patch("snekbox.api.snekapi.NsJail", autospec=True)          self.mock_nsjail = self.patcher.start() -        self.mock_nsjail.return_value.python3.return_value = CompletedProcess( +        self.mock_nsjail.return_value.python3.return_value = EvalResult(              args=[], returncode=0, stdout="output", stderr="error"          )          self.addCleanup(self.patcher.stop) diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py index 976970e..37f90e7 100644 --- a/tests/api/test_eval.py +++ b/tests/api/test_eval.py @@ -5,12 +5,19 @@ class TestEvalResource(SnekAPITestCase):      PATH = "/eval"      def test_post_valid_200(self): -        body = {"input": "foo"} -        result = self.simulate_post(self.PATH, json=body) - -        self.assertEqual(result.status_code, 200) -        self.assertEqual("output", result.json["stdout"]) -        self.assertEqual(0, result.json["returncode"]) +        cases = [ +            {"args": ["-c", "print('output')"]}, +            {"input": "print('hello')"}, +            {"input": "print('hello')", "args": ["-c"]}, +            {"input": "print('hello')", "args": [""]}, +            {"input": "pass", "args": 
["-m", "timeit"]}, +        ] +        for body in cases: +            with self.subTest(): +                result = self.simulate_post(self.PATH, json=body) +                self.assertEqual(result.status_code, 200) +                self.assertEqual("output", result.json["stdout"]) +                self.assertEqual(0, result.json["returncode"])      def test_post_invalid_schema_400(self):          body = {"stuff": "foo"} @@ -20,27 +27,100 @@ class TestEvalResource(SnekAPITestCase):          expected = {              "title": "Request data failed validation", -            "description": "'input' is a required property", +            "description": "{'stuff': 'foo'} is not valid under any of the given schemas",          }          self.assertEqual(expected, result.json)      def test_post_invalid_data_400(self): -        bodies = ({"input": 400}, {"input": "", "args": [400]}) - -        for body in bodies: +        bodies = ({"args": 400}, {"args": [], "files": [215]}) +        expects = ["400 is not of type 'array'", "215 is not of type 'object'"] +        for body, expected in zip(bodies, expects):              with self.subTest():                  result = self.simulate_post(self.PATH, json=body)                  self.assertEqual(result.status_code, 400) -                expected = { +                expected_json = {                      "title": "Request data failed validation", -                    "description": "400 is not of type 'string'", +                    "description": expected, +                } +                self.assertEqual(expected_json, result.json) + +    def test_files_path(self): +        """Normal paths should work with 200.""" +        test_paths = [ +            "file.txt", +            "./0.jpg", +            "path/to/file", +            "folder/../hm", +            "folder/./to/./somewhere", +            "traversal/but/../not/beyond/../root", +            r"backslash\\okay", +            r"backslash\okay", +            
"numbers/0123456789", +        ] +        for path in test_paths: +            with self.subTest(path=path): +                body = {"args": ["test.py"], "files": [{"path": path}]} +                result = self.simulate_post(self.PATH, json=body) +                self.assertEqual(result.status_code, 200) +                self.assertEqual("output", result.json["stdout"]) +                self.assertEqual(0, result.json["returncode"]) + +    def test_files_illegal_path_traversal(self): +        """Traversal beyond root should be denied with 400 error.""" +        test_paths = [ +            "../secrets", +            "../../dir", +            "dir/../../secrets", +            "dir/var/../../../file", +        ] +        for path in test_paths: +            with self.subTest(path=path): +                body = {"args": ["test.py"], "files": [{"path": path}]} +                result = self.simulate_post(self.PATH, json=body) +                self.assertEqual(result.status_code, 400) +                expected = { +                    "title": "Request file is invalid", +                    "description": f"File path '{path}' may not traverse beyond root",                  } -                  self.assertEqual(expected, result.json) +    def test_files_illegal_path_absolute(self): +        """Absolute file paths should 400-error at json schema validation stage.""" +        test_paths = [ +            "/", +            "/etc", +            "/etc/vars/secrets", +            "/absolute", +            "/file.bin", +        ] +        for path in test_paths: +            with self.subTest(path=path): +                body = {"args": ["test.py"], "files": [{"path": path}]} +                result = self.simulate_post(self.PATH, json=body) +                self.assertEqual(result.status_code, 400) +                self.assertEqual("Request data failed validation", result.json["title"]) +                self.assertIn("does not match", result.json["description"]) + +    def 
test_files_illegal_path_null_byte(self): +        """Paths containing \0 should 400-error at json schema validation stage.""" +        test_paths = [ +            r"etc/passwd\0", +            r"a\0b", +            r"\0", +            r"\\0", +            r"var/\0/path", +        ] +        for path in test_paths: +            with self.subTest(path=path): +                body = {"args": ["test.py"], "files": [{"path": path}]} +                result = self.simulate_post(self.PATH, json=body) +                self.assertEqual(result.status_code, 400) +                self.assertEqual("Request data failed validation", result.json["title"]) +                self.assertIn("does not match", result.json["description"]) +      def test_post_invalid_content_type_415(self):          body = "{'input': 'foo'}"          headers = {"Content-Type": "application/xml"} diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 0000000..e4d081f --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,143 @@ +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager, suppress +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import TestCase +from uuid import uuid4 + +from snekbox.filesystem import UnmountFlags, mount, unmount + + +class LibMountTests(TestCase): +    temp_dir: TemporaryDirectory + +    @classmethod +    def setUpClass(cls): +        cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests") + +    @classmethod +    def tearDownClass(cls): +        cls.temp_dir.cleanup() + +    @contextmanager +    def get_mount(self): +        """Yield a valid mount point and unmount after context.""" +        path = Path(self.temp_dir.name, str(uuid4())) +        path.mkdir() +        try: +            mount(source="", target=path, fs="tmpfs") +            yield path +        finally: +            with suppress(OSError): +                unmount(path) + +    def test_mount(self): +        
"""Test normal mounting.""" +        with self.get_mount() as path: +            self.assertTrue(path.is_mount()) +            self.assertTrue(path.exists()) +        self.assertFalse(path.is_mount()) +        # Unmounting should not remove the original folder +        self.assertTrue(path.exists()) + +    def test_mount_errors(self): +        """Test invalid mount errors.""" +        cases = [ +            (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"), +            (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"), +            ( +                dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"), +                OSError, +                "Invalid argument", +            ), +        ] +        for case, err, msg in cases: +            with self.subTest(case=case): +                with self.assertRaises(err) as cm: +                    mount(**case) +                self.assertIn(msg, str(cm.exception)) + +    def test_mount_duplicate(self): +        """Test attempted mount after mounted.""" +        path = Path(self.temp_dir.name, str(uuid4())) +        path.mkdir() +        try: +            mount(source="", target=path, fs="tmpfs") +            with self.assertRaises(OSError) as cm: +                mount(source="", target=path, fs="tmpfs") +            self.assertIn("already a mount point", str(cm.exception)) +        finally: +            unmount(target=path) + +    def test_unmount_flags(self): +        """Test unmount flags.""" +        flags = [ +            UnmountFlags.MNT_FORCE, +            UnmountFlags.MNT_DETACH, +            UnmountFlags.UMOUNT_NOFOLLOW, +        ] +        for flag in flags: +            with self.subTest(flag=flag), self.get_mount() as path: +                self.assertTrue(path.is_mount()) +                unmount(path, flag) +                self.assertFalse(path.is_mount()) + +    def test_unmount_flags_expire(self): +        """Test unmount 
MNT_EXPIRE behavior.""" +        with self.get_mount() as path: +            with self.assertRaises(BlockingIOError): +                unmount(path, UnmountFlags.MNT_EXPIRE) + +    def test_unmount_errors(self): +        """Test invalid unmount errors.""" +        cases = [ +            (dict(target="not/exist"), OSError, "is not a mount point"), +            (dict(target=Path("not/exist")), OSError, "is not a mount point"), +        ] +        for case, err, msg in cases: +            with self.subTest(case=case): +                with self.assertRaises(err) as cm: +                    unmount(**case) +                self.assertIn(msg, str(cm.exception)) + +    def test_unmount_invalid_args(self): +        """Test invalid unmount invalid flag.""" +        with self.get_mount() as path: +            with self.assertRaises(OSError) as cm: +                unmount(path, 251) +            self.assertIn("Invalid argument", str(cm.exception)) + +    def test_threading(self): +        """Test concurrent mounting works in multi-thread environments.""" +        paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)] + +        for path in paths: +            path.mkdir() +            self.assertFalse(path.is_mount()) + +        try: +            with ThreadPoolExecutor() as pool: +                res = list( +                    pool.map( +                        mount, +                        [""] * len(paths), +                        paths, +                        ["tmpfs"] * len(paths), +                    ) +                ) +                self.assertEqual(len(res), len(paths)) + +                for path in paths: +                    with self.subTest(path=path): +                        self.assertTrue(path.is_mount()) + +                unmounts = list(pool.map(unmount, paths)) +                self.assertEqual(len(unmounts), len(paths)) + +                for path in paths: +                    with self.subTest(path=path): +                    
    self.assertFalse(path.is_mount()) +        finally: +            with suppress(OSError): +                for path in paths: +                    unmount(path) diff --git a/tests/test_integration.py b/tests/test_integration.py index 7c5db2b..91b01e6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,14 +1,25 @@  import json  import unittest  import urllib.request +from base64 import b64encode  from multiprocessing.dummy import Pool +from textwrap import dedent  from tests.gunicorn_utils import run_gunicorn -def run_code_in_snekbox(code: str) -> tuple[str, int]: -    body = {"input": code} -    json_data = json.dumps(body).encode("utf-8") +def b64encode_code(data: str): +    data = dedent(data).strip() +    return b64encode(data.encode()).decode("ascii") + + +def snekbox_run_code(code: str) -> tuple[str, int]: +    body = {"args": ["-c", code]} +    return snekbox_request(body) + + +def snekbox_request(content: dict) -> tuple[str, int]: +    json_data = json.dumps(content).encode("utf-8")      req = urllib.request.Request("http://localhost:8060/eval")      req.add_header("Content-Type", "application/json; charset=utf-8") @@ -34,9 +45,71 @@ class IntegrationTests(unittest.TestCase):              args = [code] * processes              with Pool(processes) as p: -                results = p.map(run_code_in_snekbox, args) +                results = p.map(snekbox_run_code, args)              responses, statuses = zip(*results)              self.assertTrue(all(status == 200 for status in statuses))              self.assertTrue(all(json.loads(response)["returncode"] == 0 for response in responses)) + +    def test_eval(self): +        """Test normal eval requests without files.""" +        with run_gunicorn(): +            cases = [ +                ({"input": "print('Hello')"}, "Hello\n"), +                ({"args": ["-c", "print('abc12')"]}, "abc12\n"), +            ] +            for body, expected in cases: +                with 
self.subTest(body=body): +                    response, status = snekbox_request(body) +                    self.assertEqual(status, 200) +                    self.assertEqual(json.loads(response)["stdout"], expected) + +    def test_files_send_receive(self): +        """Test sending and receiving files to snekbox.""" +        with run_gunicorn(): +            request = { +                "args": ["main.py"], +                "files": [ +                    { +                        "path": "main.py", +                        "content": b64encode_code( +                            """ +                            from pathlib import Path +                            from mod import lib +                            print(lib.var) + +                            with open('test.txt', 'w') as f: +                                f.write('test 1') + +                            Path('dir').mkdir() +                            Path('dir/test2.txt').write_text('test 2') +                            """ +                        ), +                    }, +                    {"path": "mod/__init__.py"}, +                    {"path": "mod/lib.py", "content": b64encode_code("var = 'hello'")}, +                ], +            } + +            expected = { +                "stdout": "hello\n", +                "returncode": 0, +                "files": [ +                    { +                        "path": "dir/test2.txt", +                        "size": len("test 2"), +                        "content": b64encode_code("test 2"), +                    }, +                    { +                        "path": "test.txt", +                        "size": len("test 1"), +                        "content": b64encode_code("test 1"), +                    }, +                ], +            } + +            response, status = snekbox_request(request) + +            self.assertEqual(200, status) +            self.assertEqual(expected, json.loads(response)) diff --git 
a/tests/test_main.py b/tests/test_main.py index 77b3130..24c067c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase):      @patch("sys.argv", ["", "import sys; sys.exit(22)"])      def test_main_exits_with_returncode(self): -        """Should exit with the subprocess's returncode if it's non-zero.""" +        """Should exit with the subprocess returncode if it's non-zero."""          with self.assertRaises(SystemExit) as cm:              snekbox_main.main() diff --git a/tests/test_memfs.py b/tests/test_memfs.py new file mode 100644 index 0000000..0555726 --- /dev/null +++ b/tests/test_memfs.py @@ -0,0 +1,65 @@ +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import ExitStack +from unittest import TestCase, mock +from uuid import uuid4 + +from snekbox.memfs import MemFS + +UUID_TEST = uuid4() + + +class MemFSTests(TestCase): +    def setUp(self): +        super().setUp() +        self.logger = logging.getLogger("snekbox.memfs") +        self.logger.setLevel(logging.WARNING) + +    @mock.patch("snekbox.memfs.uuid4", lambda: UUID_TEST) +    def test_assignment_thread_safe(self): +        """Test concurrent mounting works in multi-thread environments.""" +        # Concurrently create MemFS in threads, check only 1 can be created +        # Others should result in RuntimeError +        with ExitStack() as stack: +            with ThreadPoolExecutor() as executor: +                memfs: MemFS | None = None +                # Each future uses enter_context to ensure __exit__ on test exception +                futures = [ +                    executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8) +                ] +                for future in futures: +                    # We should have exactly one result and all others RuntimeErrors +                    if err := future.exception(): +                        self.assertIsInstance(err, RuntimeError) + 
                   else: +                        self.assertIsNone(memfs) +                        memfs = future.result() + +                # Original memfs should still exist afterwards +                self.assertIsInstance(memfs, MemFS) +                self.assertTrue(memfs.path.is_mount()) + +    def test_cleanup(self): +        """Test explicit cleanup.""" +        memfs = MemFS(10) +        path = memfs.path +        self.assertTrue(path.is_mount()) +        memfs.cleanup() +        self.assertFalse(path.exists()) + +    def test_context_cleanup(self): +        """Context __exit__ should trigger cleanup.""" +        with MemFS(10) as memfs: +            path = memfs.path +            self.assertTrue(path.is_mount()) +        self.assertFalse(path.exists()) + +    def test_implicit_cleanup(self): +        """Test implicit _cleanup triggered by GC.""" +        memfs = MemFS(10) +        path = memfs.path +        self.assertTrue(path.is_mount()) +        # Catch the warning about implicit cleanup +        with self.assertWarns(ResourceWarning): +            del memfs +        self.assertFalse(path.exists()) diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 6f6e2a7..cad79f3 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -9,28 +9,40 @@ from itertools import product  from pathlib import Path  from textwrap import dedent +from snekbox.filesystem import Size  from snekbox.nsjail import NsJail +from snekbox.snekio import FileAttachment  class NsJailTests(unittest.TestCase):      def setUp(self):          super().setUp() -        self.nsjail = NsJail() +        # Specify lower limits for unit tests to complete within time limits +        self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB)          self.logger = logging.getLogger("snekbox.nsjail")          self.logger.setLevel(logging.WARNING) +    def eval_code(self, code: str): +        return self.nsjail.python3(["-c", code]) + +    def eval_file(self, code: str, name: str = 
"test.py", **kwargs): +        file = FileAttachment(name, code.encode()) +        return self.nsjail.python3([name], [file], **kwargs) +      def test_print_returns_0(self): -        result = self.nsjail.python3("print('test')") -        self.assertEqual(result.returncode, 0) -        self.assertEqual(result.stdout, "test\n") -        self.assertEqual(result.stderr, None) +        for fn in (self.eval_code, self.eval_file): +            with self.subTest(fn.__name__): +                result = fn("print('test')") +                self.assertEqual(result.returncode, 0) +                self.assertEqual(result.stdout, "test\n") +                self.assertEqual(result.stderr, None)      def test_timeout_returns_137(self):          code = "while True: pass"          with self.assertLogs(self.logger) as log: -            result = self.nsjail.python3(code) +            result = self.eval_file(code)          self.assertEqual(result.returncode, 137)          self.assertEqual(result.stdout, "") @@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase):          # Add a kilobyte just to be safe.          
code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}" -        result = self.nsjail.python3(code) -        self.assertEqual(result.returncode, 137) +        result = self.eval_file(code)          self.assertEqual(result.stdout, "") +        self.assertEqual(result.returncode, 137) +        self.assertEqual(result.stderr, None) + +    def test_multi_files(self): +        files = [ +            FileAttachment("main.py", "import lib; print(lib.x)".encode()), +            FileAttachment("lib.py", "x = 'hello'".encode()), +        ] + +        result = self.nsjail.python3(["main.py"], files) +        self.assertEqual(result.returncode, 0) +        self.assertEqual(result.stdout, "hello\n")          self.assertEqual(result.stderr, None)      def test_subprocess_resource_unavailable(self): +        max_pids = self.nsjail.config.cgroup_pids_max          code = dedent( -            """ +            f"""              import subprocess -            # Max PIDs is 5. -            for _ in range(6): +            # Should fail at n (max PIDs) since the caller python process counts as well +            for _ in range({max_pids}):                  print(subprocess.Popen(                      [                          '/usr/local/bin/python3', @@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase):              """          ).strip() -        result = self.nsjail.python3(code) +        result = self.eval_file(code)          self.assertEqual(result.returncode, 1)          self.assertIn("Resource temporarily unavailable", result.stdout) +        # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n" +        expected = "\n".join(map(str, range(2, max_pids + 1))) +        self.assertIn(expected, result.stdout)          self.assertEqual(result.stderr, None)      def test_multiprocess_resource_limits(self): @@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase):              """          ) -        result = self.nsjail.python3(code) +        result = 
self.eval_file(code)          exit_codes = result.stdout.strip().split()          self.assertIn("-9", exit_codes) @@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase):                      """                  ).strip() -                result = self.nsjail.python3(code) +                result = self.eval_file(code)                  self.assertEqual(result.returncode, 1)                  self.assertIn("Read-only file system", result.stdout)                  self.assertEqual(result.stderr, None) +    def test_write(self): +        code = dedent( +            """ +            from pathlib import Path +            with open('test.txt', 'w') as f: +                f.write('hello') +            print(Path('test.txt').read_text()) +            """ +        ).strip() + +        result = self.eval_file(code) +        self.assertEqual(result.returncode, 0) +        self.assertEqual(result.stdout, "hello\n") +        self.assertEqual(result.stderr, None) + +    def test_write_exceed_space(self): +        code = dedent( +            f""" +            size = {self.nsjail.memfs_instance_size} // 2048 +            with open('f.bin', 'wb') as f: +                for i in range(size): +                    f.write(b'1' * 2048) +            """ +        ).strip() + +        result = self.eval_file(code) +        self.assertEqual(result.returncode, 1) +        self.assertIn("No space left on device", result.stdout) +        self.assertEqual(result.stderr, None) +      def test_forkbomb_resource_unavailable(self):          code = dedent(              """ @@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase):              """          ).strip() -        result = self.nsjail.python3(code) +        result = self.eval_file(code)          self.assertEqual(result.returncode, 1)          self.assertIn("Resource temporarily unavailable", result.stdout)          self.assertEqual(result.stderr, None) +    def test_file_parsing_timeout(self): +        code = dedent( +            """ 
+            import os +            data = "a" * 1024 +            size = 32 * 1024 * 1024 + +            with open("file", "w") as f: +                for _ in range((size // 1024) - 5): +                    f.write(data) + +            for i in range(100): +                os.symlink("file", f"file{i}") +            """ +        ).strip() + +        nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1) +        result = nsjail.python3(["-c", code]) +        self.assertEqual(result.returncode, None) +        self.assertEqual( +            result.stdout, "TimeoutError: Exceeded time limit while parsing attachments" +        ) +        self.assertEqual(result.stderr, None) + +    def test_file_write_error(self): +        """Test errors during file write.""" +        result = self.nsjail.python3( +            [""], +            [ +                FileAttachment("dir/test.txt", b"abc"), +                FileAttachment("dir", b"xyz"), +            ], +        ) + +        self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.") +        self.assertEqual(result.stderr, None) +        self.assertEqual(result.returncode, None) +      def test_sigsegv_returns_139(self):  # In honour of Juan.          
code = dedent(              """ @@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase):              """          ).strip() -        result = self.nsjail.python3(code) +        result = self.eval_file(code)          self.assertEqual(result.returncode, 139)          self.assertEqual(result.stdout, "")          self.assertEqual(result.stderr, None)      def test_null_byte_value_error(self): -        result = self.nsjail.python3("\0") +        # This error only occurs with `-c` mode +        result = self.eval_code("\0")          self.assertEqual(result.returncode, None)          self.assertEqual(result.stdout, "ValueError: embedded null byte")          self.assertEqual(result.stderr, None)      def test_print_bad_unicode_encode_error(self): -        result = self.nsjail.python3("print(chr(56550))") +        result = self.eval_file("print(chr(56550))")          self.assertEqual(result.returncode, 1)          self.assertIn("UnicodeEncodeError", result.stdout)          self.assertEqual(result.stderr, None)      def test_unicode_env_erase_escape_fails(self): -        result = self.nsjail.python3( +        result = self.eval_file(              dedent(                  """                  import os @@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase):                      """                  ).strip() -                result = self.nsjail.python3(code) +                result = self.eval_file(code)                  self.assertEqual(result.returncode, 1)                  self.assertIn("No such file or directory", result.stdout)                  self.assertEqual(result.stderr, None) @@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase):              """          ).strip() -        result = self.nsjail.python3(code) +        result = self.eval_file(code)          self.assertEqual(result.returncode, 1)          self.assertIn("Function not implemented", result.stdout)          self.assertEqual(result.stderr, None)      def test_numpy_import(self): -        result 
= self.nsjail.python3("import numpy") +        result = self.eval_file("import numpy")          self.assertEqual(result.returncode, 0)          self.assertEqual(result.stdout, "")          self.assertEqual(result.stderr, None) @@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase):              """          ).strip() -        result = self.nsjail.python3(code) +        result = self.eval_file(code)          self.assertLess(              result.stdout.find(stdout_msg),              result.stdout.find(stderr_msg), @@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase):      def test_stdout_flood_results_in_graceful_sigterm(self):          code = "while True: print('abcdefghij')" -        result = self.nsjail.python3(code) +        result = self.eval_file(code)          self.assertEqual(result.returncode, 143)      def test_large_output_is_truncated(self): @@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase):          self.assertEqual(output, chunk * expected_chunks)      def test_nsjail_args(self): -        args = ("foo", "bar") -        result = self.nsjail.python3("", nsjail_args=args) +        args = ["foo", "bar"] +        result = self.nsjail.python3((), nsjail_args=args)          end = result.args.index("--")          self.assertEqual(result.args[end - len(args) : end], args)      def test_py_args(self): -        args = ("-m", "timeit") -        result = self.nsjail.python3("", py_args=args) - -        self.assertEqual(result.returncode, 0) -        self.assertEqual(result.args[-3:-1], args) +        cases = [ +            # Normal args +            (["-c", "print('hello')"], ["-c", "print('hello')"]), +            # Leading empty strings should be removed +            (["", "-m", "timeit"], ["-m", "timeit"]), +            (["", "", "-m", "timeit"], ["-m", "timeit"]), +            (["", "", "", "-m", "timeit"], ["-m", "timeit"]), +            # Non-leading empty strings should be preserved +            (["-m", "timeit", ""], ["-m", "timeit", ""]), +   
     ] + +        for args, expected in cases: +            with self.subTest(args=args): +                result = self.nsjail.python3(py_args=args) +                idx = result.args.index("-BSqu") +                self.assertEqual(result.args[idx + 1 :], expected) +                self.assertEqual(result.returncode, 0)  class NsJailArgsTests(unittest.TestCase): diff --git a/tests/test_snekio.py b/tests/test_snekio.py new file mode 100644 index 0000000..8f04429 --- /dev/null +++ b/tests/test_snekio.py @@ -0,0 +1,58 @@ +from unittest import TestCase + +from snekbox import snekio +from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError + + +class SnekIOTests(TestCase): +    def test_safe_path(self) -> None: +        cases = [ +            ("", ""), +            ("foo", "foo"), +            ("foo/bar", "foo/bar"), +            ("foo/bar.ext", "foo/bar.ext"), +        ] + +        for path, expected in cases: +            self.assertEqual(snekio.safe_path(path), expected) + +    def test_safe_path_raise(self): +        cases = [ +            ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"), +            ("/foo", IllegalPathError, "File path '/foo' must be relative"), +        ] + +        for path, error, msg in cases: +            with self.assertRaises(error) as cm: +                snekio.safe_path(path) +            self.assertEqual(str(cm.exception), msg) + +    def test_file_from_dict(self): +        cases = [ +            ({"path": "foo", "content": ""}, FileAttachment("foo", b"")), +            ({"path": "foo"}, FileAttachment("foo", b"")), +            ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")), +            ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")), +        ] + +        for data, expected in cases: +            self.assertEqual(FileAttachment.from_dict(data), expected) + +    def test_file_from_dict_error(self): +        cases = [ +           
 ( +                {"path": "foo", "content": "9"}, +                ParsingError, +                "Invalid base64 encoding for file 'foo'", +            ), +            ( +                {"path": "var/a.txt", "content": "1="}, +                ParsingError, +                "Invalid base64 encoding for file 'var/a.txt'", +            ), +        ] + +        for data, error, msg in cases: +            with self.assertRaises(error) as cm: +                FileAttachment.from_dict(data) +            self.assertEqual(str(cm.exception), msg) | 
