diff options
-rw-r--r-- | README.md | 21 | ||||
-rw-r--r-- | config/snekbox.cfg | 8 | ||||
-rw-r--r-- | docker-compose.yml | 2 | ||||
-rw-r--r-- | snekbox/__init__.py | 2 | ||||
-rw-r--r-- | snekbox/__main__.py | 2 | ||||
-rw-r--r-- | snekbox/api/resources/eval.py | 82 | ||||
-rw-r--r-- | snekbox/filesystem.py | 88 | ||||
-rw-r--r-- | snekbox/memfs.py | 185 | ||||
-rw-r--r-- | snekbox/nsjail.py | 137 | ||||
-rw-r--r-- | snekbox/process.py | 32 | ||||
-rw-r--r-- | snekbox/snekio.py | 99 | ||||
-rw-r--r-- | snekbox/utils/__init__.py | 4 | ||||
-rw-r--r-- | snekbox/utils/timed.py | 37 | ||||
-rw-r--r-- | tests/api/__init__.py | 4 | ||||
-rw-r--r-- | tests/api/test_eval.py | 106 | ||||
-rw-r--r-- | tests/test_filesystem.py | 143 | ||||
-rw-r--r-- | tests/test_integration.py | 81 | ||||
-rw-r--r-- | tests/test_main.py | 2 | ||||
-rw-r--r-- | tests/test_memfs.py | 65 | ||||
-rw-r--r-- | tests/test_nsjail.py | 170 | ||||
-rw-r--r-- | tests/test_snekio.py | 58 |
21 files changed, 1231 insertions, 97 deletions
@@ -7,6 +7,9 @@ Python sandbox runners for executing code in isolation aka snekbox. +Supports a memory [virtual read/write file system](#virtual-file-system) within the sandbox, +allowing text or binary files to be sent and returned. + A client sends Python code to a snekbox, the snekbox executes the code, and finally the results of the execution are returned to the client. ```mermaid @@ -60,10 +63,26 @@ The main features of the default configuration are: * Memory limit * Process count limit * No networking -* Restricted, read-only filesystem +* Restricted, read-only system filesystem +* Memory-based read-write filesystem mounted as working directory `/home` NsJail is configured through [`snekbox.cfg`]. It contains the exact values for the items listed above. The configuration format is defined by a [protobuf file][7] which can be referred to for documentation. The command-line options of NsJail can also serve as documentation since they closely follow the config file format. +### Memory File System + +On each execution, the host will mount an instance-specific `tmpfs` drive, this is used as a limited read-write folder for the sandboxed code. There is no access to other files or directories on the host container beyond the other read-only mounted system folders. Instance file systems are isolated; it is not possible for sandboxed code to access another instance's writeable directory. + +The following options for the memory file system are configurable as options in [gunicorn.conf.py](config/gunicorn.conf.py) + +* `memfs_instance_size` Size in bytes for the capacity of each instance file system. +* `memfs_home` Path to the home directory within the instance file system. +* `memfs_output` Path to the output directory within the instance file system. +* `files_limit` Maximum number of valid output files to parse. +* `files_timeout` Maximum time in seconds for output file parsing and encoding. +* `files_pattern` Glob pattern to match files within `output`. + +The sandboxed code execution will start with a writeable working directory of `home`. By default, the output folder is also `home`. New files, and uploaded files with a newer last modified time, will be uploaded on completion. + ### Gunicorn [Gunicorn settings] can be found in [`gunicorn.conf.py`]. In the default configuration, the worker count, the bind address, and the WSGI app URI are likely the only things of any interest. Since it uses the default synchronous workers, the [worker count] effectively determines how many concurrent code evaluations can be performed. diff --git a/config/snekbox.cfg b/config/snekbox.cfg index 87c216e..5dd63da 100644 --- a/config/snekbox.cfg +++ b/config/snekbox.cfg @@ -3,7 +3,7 @@ description: "Execute Python" mode: ONCE hostname: "snekbox" -cwd: "/snekbox" +cwd: "/home" time_limit: 6 @@ -16,10 +16,12 @@ envar: "VECLIB_MAXIMUM_THREADS=5" envar: "NUMEXPR_NUM_THREADS=5" envar: "PYTHONPATH=/snekbox/user_base/lib/python3.11/site-packages" envar: "PYTHONIOENCODING=utf-8:strict" +envar: "HOME=home" keep_caps: false rlimit_as: 700 +rlimit_fsize_type: INF clone_newnet: true clone_newuser: true @@ -108,12 +110,12 @@ cgroup_mem_max: 52428800 cgroup_mem_swap_max: 0 cgroup_mem_mount: "/sys/fs/cgroup/memory" -cgroup_pids_max: 5 +cgroup_pids_max: 6 cgroup_pids_mount: "/sys/fs/cgroup/pids" iface_no_lo: true exec_bin { path: "/usr/local/bin/python" - arg: "-Squ" + arg: "-BSqu" } diff --git a/docker-compose.yml b/docker-compose.yml index 9d3ae71..0613abc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,7 @@ services: image: ghcr.io/python-discord/snekbox${IMAGE_SUFFIX:--venv:dev} pull_policy: never ports: - - 8060:8060 + - "8060:8060" init: true ipc: none tty: true diff --git a/snekbox/__init__.py b/snekbox/__init__.py index bed3692..b45960b 100644 --- a/snekbox/__init__.py +++ b/snekbox/__init__.py @@ -12,7 +12,7 @@ from snekbox.api import SnekAPI # noqa: E402 from snekbox.nsjail import NsJail # noqa: E402 from snekbox.utils.logging import init_logger, init_sentry # noqa: E402 -__all__ = ("NsJail", "SnekAPI") +__all__ = ("NsJail", "SnekAPI", "DEBUG") init_sentry(__version__) init_logger(DEBUG) diff --git a/snekbox/__main__.py b/snekbox/__main__.py index 2382c4c..239f5d5 100644 --- a/snekbox/__main__.py +++ b/snekbox/__main__.py @@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace: def main() -> None: """Evaluate Python code through NsJail.""" args = parse_args() - result = NsJail().python3(args.code, nsjail_args=args.nsjail_args, py_args=args.py_args) + result = NsJail().python3(py_args=[*args.py_args, args.code], nsjail_args=args.nsjail_args) print(result.stdout) if result.returncode != 0: diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py index 1df6c1b..9a53577 100644 --- a/snekbox/api/resources/eval.py +++ b/snekbox/api/resources/eval.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import falcon @@ -7,6 +9,8 @@ from snekbox.nsjail import NsJail __all__ = ("EvalResource",) +from snekbox.snekio import FileAttachment, ParsingError + log = logging.getLogger(__name__) @@ -25,8 +29,26 @@ class EvalResource: "properties": { "input": {"type": "string"}, "args": {"type": "array", "items": {"type": "string"}}, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "path": { + "type": "string", + # Disallow starting with / or containing \0 anywhere + "pattern": r"^(?!/)(?!.*\\0).*$", + }, + "content": {"type": "string"}, + }, + "required": ["path"], + }, + }, }, - "required": ["input"], + "anyOf": [ + {"required": ["input"]}, + {"required": ["args"]}, + ], } def __init__(self, nsjail: NsJail): @@ -38,8 +60,11 @@ class EvalResource: Evaluate Python code and return stdout, stderr, and the return code. A list of arguments for the Python subprocess can be specified as `args`. - Otherwise, the default argument "-c" is used to execute the input code. - The input code is always passed as the last argument to Python. + + If `input` is specified, it will be appended as the last argument to `args`, + and `args` will have a default argument of `"-c"`. + + Either `input` or `args` must be specified. The return codes mostly resemble those of a Unix shell. Some noteworthy cases: @@ -53,15 +78,35 @@ class EvalResource: Request body: >>> { - ... "input": "[i for i in range(1000)]", - ... "args": ["-m", "timeit"] # This is optional + ... "input": "print('Hello')" + ... } + + >>> { + ... "args": ["-c", "print('Hello')"] + ... } + + >>> { + ... "args": ["main.py"], + ... "files": [ + ... { + ... "path": "main.py", + ... "content": "SGVsbG8...=" # Base64 + ... } + ... ] ... } Response format: >>> { - ... "stdout": "10000 loops, best of 5: 23.8 usec per loop\n", - ... "returncode": 0 + ... "stdout": "10000 loops, best of 5: 23.8 usec per loop", + ... "returncode": 0, + ... "files": [ + ... { + ... "path": "output.png", + ... "size": 57344, + ... "content": "eJzzSM3...=" # Base64 + ... } + ... ] ... } Status codes: @@ -69,17 +114,28 @@ class EvalResource: - 200 Successful evaluation; not indicative that the input code itself works - 400 - Input's JSON schema is invalid + Input JSON schema is invalid - 415 Unsupported content type; only application/JSON is supported """ - code = req.media["input"] - args = req.media.get("args", ("-c",)) - + body: dict[str, str | list[str] | list[dict[str, str]]] = req.media + # If `input` is supplied, default `args` to `-c` + if "input" in body: + body.setdefault("args", ["-c"]) + body["args"].append(body["input"]) try: - result = self.nsjail.python3(code, py_args=args) + result = self.nsjail.python3( + py_args=body["args"], + files=[FileAttachment.from_dict(file) for file in body.get("files", [])], + ) + except ParsingError as e: + raise falcon.HTTPBadRequest(title="Request file is invalid", description=str(e)) except Exception: log.exception("An exception occurred while trying to process the request") raise falcon.HTTPInternalServerError - resp.media = {"stdout": result.stdout, "returncode": result.returncode} + resp.media = { + "stdout": result.stdout, + "returncode": result.returncode, + "files": [f.as_dict for f in result.files], + } diff --git a/snekbox/filesystem.py b/snekbox/filesystem.py new file mode 100644 index 0000000..312707c --- /dev/null +++ b/snekbox/filesystem.py @@ -0,0 +1,88 @@ +"""Mounts and unmounts filesystems.""" +from __future__ import annotations + +import ctypes +import os +from ctypes.util import find_library +from enum import IntEnum +from pathlib import Path + +__all__ = ("mount", "unmount", "Size", "UnmountFlags") + +libc = ctypes.CDLL(find_library("c"), use_errno=True) +libc.mount.argtypes = ( + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_ulong, + ctypes.c_char_p, +) +libc.umount2.argtypes = (ctypes.c_char_p, ctypes.c_int) + + +class Size(IntEnum): + """Size multipliers for bytes.""" + + KiB = 1024 + MiB = 1024**2 + GiB = 1024**3 + TiB = 1024**4 + + +class UnmountFlags(IntEnum): + """Flags for umount2.""" + + MNT_FORCE = 1 + MNT_DETACH = 2 + MNT_EXPIRE = 4 + UMOUNT_NOFOLLOW = 8 + + +def mount(source: Path | str, target: Path | str, fs: str, **options: str | int) -> None: + """ + Mount a filesystem. + + https://man7.org/linux/man-pages/man8/mount.8.html + + Args: + source: Source directory or device. + target: Target directory. + fs: Filesystem type. + **options: Mount options. + + Raises: + OSError: On any mount error. + """ + if Path(target).is_mount(): + raise OSError(f"{target} is already a mount point") + + kwargs = " ".join(f"{key}={value}" for key, value in options.items()) + + result: int = libc.mount( + str(source).encode(), str(target).encode(), fs.encode(), 0, kwargs.encode() + ) + if result < 0: + errno = ctypes.get_errno() + raise OSError(errno, f"Error mounting {target}: {os.strerror(errno)}") + + +def unmount(target: Path | str, flags: UnmountFlags | int = UnmountFlags.MNT_DETACH) -> None: + """ + Unmount a filesystem. + + https://man7.org/linux/man-pages/man2/umount.2.html + + Args: + target: Target directory. + flags: Unmount flags. + + Raises: + OSError: On any unmount error. + """ + if not Path(target).is_mount(): + raise OSError(f"{target} is not a mount point") + + result: int = libc.umount2(str(target).encode(), int(flags)) + if result < 0: + errno = ctypes.get_errno() + raise OSError(errno, f"Error unmounting {target}: {os.strerror(errno)}") diff --git a/snekbox/memfs.py b/snekbox/memfs.py new file mode 100644 index 0000000..f32fed1 --- /dev/null +++ b/snekbox/memfs.py @@ -0,0 +1,185 @@ +"""Memory filesystem for snekbox.""" +from __future__ import annotations + +import logging +import warnings +import weakref +from collections.abc import Generator +from contextlib import suppress +from pathlib import Path +from types import TracebackType +from typing import Type +from uuid import uuid4 + +from snekbox.filesystem import mount, unmount +from snekbox.snekio import FileAttachment + +log = logging.getLogger(__name__) + +__all__ = ("MemFS",) + + +class MemFS: + """An in-memory temporary file system.""" + + def __init__( + self, + instance_size: int, + root_dir: str | Path = "/memfs", + home: str = "home", + output: str = "home", + ) -> None: + """ + Initialize an in-memory temporary file system. + + Examples: + >>> with MemFS(1024) as memfs: + ... (memfs.home / "test.txt").write_text("Hello") + + Args: + instance_size: Size limit of each tmpfs instance in bytes. + root_dir: Root directory to mount instances in. + home: Name of the home directory. + output: Name of the output directory within home. If empty, uses home. + """ + self.instance_size = instance_size + self.root_dir = Path(root_dir) + self.root_dir.mkdir(exist_ok=True, parents=True) + self._home_name = home + self._output_name = output + + for _ in range(10): + name = str(uuid4()) + try: + self.path = self.root_dir / name + self.path.mkdir() + mount("", self.path, "tmpfs", size=self.instance_size) + break + except OSError: + continue + else: + raise RuntimeError("Failed to generate a unique MemFS name in 10 attempts") + + self.mkdir(self.home) + self.mkdir(self.output) + + self._finalizer = weakref.finalize( + self, + self._cleanup, + self.path, + warn_message=f"Implicitly cleaning up {self!r}", + ) + + @classmethod + def _cleanup(cls, path: Path, warn_message: str): + """Implicit cleanup of the MemFS.""" + with suppress(OSError): + unmount(path) + path.rmdir() + warnings.warn(warn_message, ResourceWarning) + + def cleanup(self) -> None: + """Unmount the tempfs and remove the directory.""" + if self._finalizer.detach() or self.path.exists(): + unmount(self.path) + self.path.rmdir() + + @property + def name(self) -> str: + """Name of the temp dir.""" + return self.path.name + + @property + def home(self) -> Path: + """Path to home directory.""" + return self.path / self._home_name + + @property + def output(self) -> Path: + """Path to output directory.""" + return self.path / self._output_name + + def __enter__(self) -> MemFS: + return self + + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.cleanup() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.path}>" + + def mkdir(self, path: Path | str, chmod: int = 0o777) -> Path: + """Create a directory in the tempdir.""" + folder = Path(self.path, path) + folder.mkdir(parents=True, exist_ok=True) + folder.chmod(chmod) + return folder + + def files( + self, + limit: int, + pattern: str = "**/*", + exclude_files: dict[Path, float] | None = None, + ) -> Generator[FileAttachment, None, None]: + """ + Yields FileAttachments for files found in the output directory. + + Args: + limit: The maximum number of files to parse. + pattern: The glob pattern to match files against. + exclude_files: A dict of Paths and last modified times. + Files will be excluded if their last modified time + is equal to the provided value. + """ + count = 0 + for file in self.output.rglob(pattern): + if exclude_files and (orig_time := exclude_files.get(file)): + new_time = file.stat().st_mtime + log.info(f"Checking {file.name} ({orig_time=}, {new_time=})") + if file.stat().st_mtime == orig_time: + log.info(f"Skipping {file.name!r} as it has not been modified") + continue + + if count > limit: + log.info(f"Max attachments {limit} reached, skipping remaining files") + break + + if file.is_file(): + count += 1 + log.info(f"Found valid file for upload {file.name!r}") + yield FileAttachment.from_path(file, relative_to=self.output) + + def files_list( + self, + limit: int, + pattern: str, + exclude_files: dict[Path, float] | None = None, + preload_dict: bool = False, + ) -> list[FileAttachment]: + """ + Return a sorted list of file paths within the output directory. + + Args: + limit: The maximum number of files to parse. + pattern: The glob pattern to match files against. + exclude_files: A dict of Paths and last modified times. + Files will be excluded if their last modified time + is equal to the provided value. + preload_dict: Whether to preload as_dict property data. + Returns: + List of FileAttachments sorted lexically by path name. + """ + res = sorted( + self.files(limit=limit, pattern=pattern, exclude_files=exclude_files), + key=lambda f: f.path, + ) + if preload_dict: + for file in res: + # Loads the cached property as attribute + _ = file.as_dict + return res diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 63afdef..f014850 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -2,26 +2,43 @@ import logging import re import subprocess import sys -import textwrap -from subprocess import CompletedProcess +from collections.abc import Generator +from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Iterable +from typing import Iterable, TypeVar from google.protobuf import text_format from snekbox import DEBUG, utils from snekbox.config_pb2 import NsJailConfig +from snekbox.filesystem import Size +from snekbox.memfs import MemFS +from snekbox.process import EvalResult +from snekbox.snekio import FileAttachment +from snekbox.utils.timed import timed __all__ = ("NsJail",) log = logging.getLogger(__name__) +_T = TypeVar("_T") + # [level][timestamp][PID]? function_signature:line_no? message LOG_PATTERN = re.compile( r"\[(?P<level>(I)|[DWEF])\]\[.+?\](?(2)|(?P<func>\[\d+\] .+?:\d+ )) ?(?P<msg>.+)" ) +def iter_lstrip(iterable: Iterable[_T]) -> Generator[_T, None, None]: + """Remove leading falsy objects from an iterable.""" + it = iter(iterable) + for item in it: + if item: + yield item + break + yield from it + + class NsJail: """ Core Snekbox functionality, providing safe execution of Python code. @@ -35,12 +52,41 @@ class NsJail: config_path: str = "./config/snekbox.cfg", max_output_size: int = 1_000_000, read_chunk_size: int = 10_000, + memfs_instance_size: int = 48 * Size.MiB, + memfs_home: str = "home", + memfs_output: str = "home", + files_limit: int | None = 100, + files_timeout: float | None = 8, + files_pattern: str = "**/[!_]*", ): + """ + Initialize NsJail. + + Args: + nsjail_path: Path to the NsJail binary. + config_path: Path to the NsJail configuration file. + max_output_size: Maximum size of the output in bytes. + read_chunk_size: Size of the read buffer in bytes. + memfs_instance_size: Size of the tmpfs instance in bytes. + memfs_home: Name of the mounted home directory. + memfs_output: Name of the output directory within home, + can be empty to use home as output. + files_limit: Maximum number of output files to parse. + files_timeout: Maximum time in seconds to wait for output files to be read. + files_pattern: Pattern to match files to attach within the output directory. + """ self.nsjail_path = nsjail_path self.config_path = config_path self.max_output_size = max_output_size self.read_chunk_size = read_chunk_size + self.memfs_instance_size = memfs_instance_size + self.memfs_home = memfs_home + self.memfs_output = memfs_output + self.files_limit = files_limit + self.files_timeout = files_timeout + self.files_pattern = files_pattern + self.config = self._read_config(config_path) self.cgroup_version = utils.cgroup.init(self.config) self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version) @@ -129,16 +175,18 @@ class NsJail: return "".join(output) def python3( - self, code: str, *, nsjail_args: Iterable[str] = (), py_args: Iterable[str] = ("-c",) - ) -> CompletedProcess: + self, + py_args: Iterable[str], + files: Iterable[FileAttachment] = (), + nsjail_args: Iterable[str] = (), + ) -> EvalResult: """ Execute Python 3 code in an isolated environment and return the completed process. - The `nsjail_args` passed will be used to override the values in the NsJail config. - These arguments are only options for NsJail; they do not affect Python's arguments. - - `py_args` are arguments to pass to the Python subprocess before the code, - which is the last argument. By default, it's "-c", which executes the code given. + Args: + py_args: Arguments to pass to Python. + files: FileAttachments to write to the sandbox prior to running Python. + nsjail_args: Overrides for the NsJail configuration. """ if self.cgroup_version == 2: nsjail_args = ("--use_cgroupv2", *nsjail_args) @@ -152,8 +200,19 @@ class NsJail: *nsjail_args, ) - with NamedTemporaryFile() as nsj_log: - args = ( + with NamedTemporaryFile() as nsj_log, MemFS( + instance_size=self.memfs_instance_size, + home=self.memfs_home, + output=self.memfs_output, + ) as fs: + nsjail_args = ( + # Mount `home` with Read/Write access + "--bindmount", + f"{fs.home}:home", + *nsjail_args, + ) + + args = [ self.nsjail_path, "--config", self.config_path, @@ -163,13 +222,30 @@ class NsJail: "--", self.config.exec_bin.path, *self.config.exec_bin.arg, - *py_args, - code, - ) + # Filter out empty strings at start of py_args + # (causes issues with python cli) + *iter_lstrip(py_args), + ] + + # Write provided files if any + files_written: dict[Path, float] = {} + for file in files: + try: + f_path = file.save_to(fs.home) + # Allow file to be writable + f_path.chmod(0o777) + # Save the written at time to later check if it was modified + files_written[f_path] = f_path.stat().st_mtime + log.info(f"Created file at {(fs.home / file.path)!r}.") + except OSError as e: + log.info(f"Failed to create file at {(fs.home / file.path)!r}.", exc_info=e) + return EvalResult( + args, None, f"{e.__class__.__name__}: Failed to create file '{file.path}'." + ) msg = "Executing code..." if DEBUG: - msg = f"{msg[:-3]}:\n{textwrap.indent(code, ' ')}\nWith the arguments {args}." + msg = f"{msg[:-3]} with the arguments {args}." log.info(msg) try: @@ -177,23 +253,36 @@ class NsJail: args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True ) except ValueError: - return CompletedProcess(args, None, "ValueError: embedded null byte", None) + return EvalResult(args, None, "ValueError: embedded null byte") try: output = self._consume_stdout(nsjail) except UnicodeDecodeError: - return CompletedProcess( - args, - None, - "UnicodeDecodeError: invalid Unicode in output pipe", - None, - ) + return EvalResult(args, None, "UnicodeDecodeError: invalid Unicode in output pipe") # When you send signal `N` to a subprocess to terminate it using Popen, it # will return `-N` as its exit code. As we normally get `N + 128` back, we # convert negative exit codes to the `N + 128` form. returncode = -nsjail.returncode + 128 if nsjail.returncode < 0 else nsjail.returncode + # Parse attachments with time limit + try: + attachments = timed( + MemFS.files_list, + (fs, self.files_limit, self.files_pattern), + { + "preload_dict": True, + "exclude_files": files_written, + }, + timeout=self.files_timeout, + ) + log.info(f"Found {len(attachments)} files.") + except TimeoutError as e: + log.info(f"Exceeded time limit while parsing attachments: {e}") + return EvalResult( + args, None, "TimeoutError: Exceeded time limit while parsing attachments" + ) + log_lines = nsj_log.read().decode("utf-8").splitlines() if not log_lines and returncode == 255: # NsJail probably failed to parse arguments so log output will still be in stdout @@ -203,4 +292,4 @@ class NsJail: log.info(f"nsjail return code: {returncode}") - return CompletedProcess(args, returncode, output, None) + return EvalResult(args, returncode, output, files=attachments) diff --git a/snekbox/process.py b/snekbox/process.py new file mode 100644 index 0000000..552b91a --- /dev/null +++ b/snekbox/process.py @@ -0,0 +1,32 @@ +"""Utilities for process management.""" +from collections.abc import Sequence +from os import PathLike +from subprocess import CompletedProcess +from typing import TypeVar + +from snekbox.snekio import FileAttachment + +_T = TypeVar("_T") +ArgType = ( + str + | bytes + | PathLike[str] + | PathLike[bytes] + | Sequence[str | bytes | PathLike[str] | PathLike[bytes]] +) + + +class EvalResult(CompletedProcess[_T]): + """An evaluation job that has finished running.""" + + def __init__( + self, + args: ArgType, + returncode: int | None, + stdout: _T | None = None, + stderr: _T | None = None, + files: list[FileAttachment] | None = None, + ) -> None: + """Create an evaluation result.""" + super().__init__(args, returncode, stdout, stderr) + self.files: list[FileAttachment] = files or [] diff --git a/snekbox/snekio.py b/snekbox/snekio.py new file mode 100644 index 0000000..821f057 --- /dev/null +++ b/snekbox/snekio.py @@ -0,0 +1,99 @@ +"""I/O Operations for sending / receiving files from the sandbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path + + +def safe_path(path: str) -> str: + """ + Return `path` if there are no security issues. + + Raises: + IllegalPathError: Raised on any path rule violation. + """ + # Disallow absolute paths + if Path(path).is_absolute(): + raise IllegalPathError(f"File path '{path}' must be relative") + + # Disallow traversal beyond root + try: + test_root = Path("/home") + Path(test_root).joinpath(path).resolve().relative_to(test_root.resolve()) + except ValueError: + raise IllegalPathError(f"File path '{path}' may not traverse beyond root") + + return path + + +class ParsingError(ValueError): + """Raised when an incoming content cannot be parsed.""" + + +class IllegalPathError(ParsingError): + """Raised when a request file has an illegal path.""" + + +@dataclass(frozen=True) +class FileAttachment: + """A file attachment.""" + + path: str + content: bytes + + def __repr__(self) -> str: + path = f"{self.path[:30]}..." if len(self.path) > 30 else self.path + content = f"{self.content[:15]}..." if len(self.content) > 15 else self.content + return f"{self.__class__.__name__}(path={path!r}, content={content!r})" + + @classmethod + def from_dict(cls, data: dict[str, str]) -> FileAttachment: + """ + Convert a dict to an attachment. + + Raises: + ParsingError: Raised when the dict has invalid base64 `content`. + """ + path = safe_path(data["path"]) + try: + content = b64decode(data.get("content", "")) + except (TypeError, ValueError) as e: + raise ParsingError(f"Invalid base64 encoding for file '{path}'") from e + return cls(path, content) + + @classmethod + def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment: + """ + Create an attachment from a file path. + + Args: + file: The file to attach. + relative_to: The root for the path name. + """ + path = file.relative_to(relative_to) if relative_to else file + return cls(str(path), file.read_bytes()) + + @property + def size(self) -> int: + """Size of the attachment.""" + return len(self.content) + + def save_to(self, directory: Path | str) -> Path: + """Write the attachment to a file in `directory`. Return a Path of the file.""" + file = Path(directory, self.path) + # Create directories if they don't exist + file.parent.mkdir(parents=True, exist_ok=True) + file.write_bytes(self.content) + return file + + @cached_property + def as_dict(self) -> dict[str, str | int]: + """Convert the attachment to a dict.""" + content = b64encode(self.content).decode("ascii") + return { + "path": self.path, + "size": self.size, + "content": content, + } diff --git a/snekbox/utils/__init__.py b/snekbox/utils/__init__.py index 6d6bc32..010fa65 100644 --- a/snekbox/utils/__init__.py +++ b/snekbox/utils/__init__.py @@ -1,3 +1,3 @@ -from . import cgroup, logging, swap +from . import cgroup, logging, swap, timed -__all__ = ("cgroup", "logging", "swap") +__all__ = ("cgroup", "logging", "swap", "timed") diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py new file mode 100644 index 0000000..02388ff --- /dev/null +++ b/snekbox/utils/timed.py @@ -0,0 +1,37 @@ +"""Calling functions with time limits.""" +import multiprocessing +from collections.abc import Callable, Iterable, Mapping +from typing import Any, TypeVar + +_T = TypeVar("_T") +_V = TypeVar("_V") + +__all__ = ("timed",) + + +def timed( + func: Callable[[_T], _V], + args: Iterable = (), + kwds: Mapping[str, Any] | None = None, + timeout: float | None = None, +) -> _V: + """ + Call a function with a time limit. + + Args: + func: Function to call. + args: Arguments for function. + kwds: Keyword arguments for function. + timeout: Timeout limit in seconds. + + Raises: + TimeoutError: If the function call takes longer than `timeout` seconds. + """ + if kwds is None: + kwds = {} + with multiprocessing.Pool(1) as pool: + result = pool.apply_async(func, args, kwds) + try: + return result.get(timeout) + except multiprocessing.TimeoutError as e: + raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 0e6e422..5f20faf 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -1,10 +1,10 @@ import logging -from subprocess import CompletedProcess from unittest import mock from falcon import testing from snekbox.api import SnekAPI +from snekbox.process import EvalResult class SnekAPITestCase(testing.TestCase): @@ -13,7 +13,7 @@ class SnekAPITestCase(testing.TestCase): self.patcher = mock.patch("snekbox.api.snekapi.NsJail", autospec=True) self.mock_nsjail = self.patcher.start() - self.mock_nsjail.return_value.python3.return_value = CompletedProcess( + self.mock_nsjail.return_value.python3.return_value = EvalResult( args=[], returncode=0, stdout="output", stderr="error" ) self.addCleanup(self.patcher.stop) diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py index 976970e..37f90e7 100644 --- a/tests/api/test_eval.py +++ b/tests/api/test_eval.py @@ -5,12 +5,19 @@ class TestEvalResource(SnekAPITestCase): PATH = "/eval" def test_post_valid_200(self): - body = {"input": "foo"} - result = self.simulate_post(self.PATH, json=body) - - self.assertEqual(result.status_code, 200) - self.assertEqual("output", result.json["stdout"]) - self.assertEqual(0, result.json["returncode"]) + cases = [ + {"args": ["-c", "print('output')"]}, + {"input": "print('hello')"}, + {"input": "print('hello')", "args": ["-c"]}, + {"input": "print('hello')", "args": [""]}, + {"input": "pass", "args": ["-m", "timeit"]}, + ] + for body in cases: + with self.subTest(): + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 200) + self.assertEqual("output", result.json["stdout"]) + self.assertEqual(0, result.json["returncode"]) def test_post_invalid_schema_400(self): body = {"stuff": "foo"} @@ -20,27 +27,100 @@ class TestEvalResource(SnekAPITestCase): expected = { "title": "Request data failed validation", - "description": "'input' is a required property", + "description": "{'stuff': 'foo'} is not valid under any of the given schemas", } self.assertEqual(expected, result.json) def test_post_invalid_data_400(self): - bodies = ({"input": 400}, {"input": "", "args": [400]}) - - for body in bodies: + bodies = ({"args": 400}, {"args": [], "files": [215]}) + expects = ["400 is not of type 'array'", "215 is not of type 'object'"] + for body, expected in zip(bodies, expects): with self.subTest(): result = self.simulate_post(self.PATH, json=body) self.assertEqual(result.status_code, 400) - expected = { + expected_json = { "title": "Request data failed validation", - "description": "400 is not of type 'string'", + "description": expected, + } + self.assertEqual(expected_json, result.json) + + def test_files_path(self): + """Normal paths should work with 200.""" + test_paths = [ + "file.txt", + "./0.jpg", + "path/to/file", + "folder/../hm", + "folder/./to/./somewhere", + "traversal/but/../not/beyond/../root", + r"backslash\\okay", + r"backslash\okay", + "numbers/0123456789", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 200) + self.assertEqual("output", result.json["stdout"]) + self.assertEqual(0, result.json["returncode"]) + + def test_files_illegal_path_traversal(self): + """Traversal beyond root should be denied with 400 error.""" + test_paths = [ + "../secrets", + "../../dir", + "dir/../../secrets", + "dir/var/../../../file", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 400) + expected = { + "title": "Request file is invalid", + "description": f"File path '{path}' may not traverse beyond root", } - self.assertEqual(expected, result.json) + def test_files_illegal_path_absolute(self): + """Absolute file paths should 400-error at json schema validation stage.""" + test_paths = [ + "/", + "/etc", + "/etc/vars/secrets", + "/absolute", + "/file.bin", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 400) + self.assertEqual("Request data failed validation", result.json["title"]) + self.assertIn("does not match", result.json["description"]) + + def test_files_illegal_path_null_byte(self): + """Paths containing \0 should 400-error at json schema validation stage.""" + test_paths = [ + r"etc/passwd\0", + r"a\0b", + r"\0", + r"\\0", + r"var/\0/path", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 400) + self.assertEqual("Request data failed validation", result.json["title"]) + self.assertIn("does not match", result.json["description"]) + def test_post_invalid_content_type_415(self): body = "{'input': 'foo'}" headers = {"Content-Type": "application/xml"} diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 0000000..e4d081f --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,143 @@ +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager, suppress +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import TestCase +from uuid import uuid4 + +from snekbox.filesystem import UnmountFlags, mount, unmount + + +class LibMountTests(TestCase): + temp_dir: TemporaryDirectory + + @classmethod + def setUpClass(cls): + cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests") + + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + + @contextmanager + def get_mount(self): + """Yield a valid mount point and unmount after context.""" + path = Path(self.temp_dir.name, str(uuid4())) + path.mkdir() + try: + mount(source="", target=path, fs="tmpfs") + yield path + finally: + with suppress(OSError): + unmount(path) + + def test_mount(self): + """Test normal mounting.""" + with self.get_mount() as path: + self.assertTrue(path.is_mount()) + self.assertTrue(path.exists()) + self.assertFalse(path.is_mount()) + # Unmounting should not remove the original folder + self.assertTrue(path.exists()) + + def test_mount_errors(self): + """Test invalid mount errors.""" + cases = [ + (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"), + (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"), + ( + dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"), + OSError, + "Invalid argument", + ), + ] + for case, err, msg in cases: + with self.subTest(case=case): + with self.assertRaises(err) as cm: + mount(**case) + self.assertIn(msg, str(cm.exception)) + + def test_mount_duplicate(self): + """Test attempted mount after mounted.""" + path = Path(self.temp_dir.name, str(uuid4())) + path.mkdir() + try: + mount(source="", target=path, fs="tmpfs") + with self.assertRaises(OSError) as cm: + mount(source="", target=path, fs="tmpfs") + self.assertIn("already a mount point", str(cm.exception)) + finally: + unmount(target=path) + + def test_unmount_flags(self): + """Test unmount flags.""" + flags = [ + UnmountFlags.MNT_FORCE, + UnmountFlags.MNT_DETACH, + UnmountFlags.UMOUNT_NOFOLLOW, + ] + for flag in flags: + with self.subTest(flag=flag), self.get_mount() as path: + self.assertTrue(path.is_mount()) + unmount(path, flag) + self.assertFalse(path.is_mount()) + + def test_unmount_flags_expire(self): + """Test unmount MNT_EXPIRE behavior.""" + with self.get_mount() as path: + with self.assertRaises(BlockingIOError): + unmount(path, UnmountFlags.MNT_EXPIRE) + + def test_unmount_errors(self): + """Test invalid unmount errors.""" + cases = [ + (dict(target="not/exist"), OSError, "is not a mount point"), + (dict(target=Path("not/exist")), OSError, "is not a mount point"), + ] + for case, err, msg in cases: + with self.subTest(case=case): + with self.assertRaises(err) as cm: + unmount(**case) + self.assertIn(msg, str(cm.exception)) + + def test_unmount_invalid_args(self): + """Test invalid unmount invalid flag.""" + with self.get_mount() as path: + with self.assertRaises(OSError) as cm: + unmount(path, 251) + self.assertIn("Invalid argument", str(cm.exception)) + + def test_threading(self): + """Test concurrent mounting works in multi-thread environments.""" + paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)] + + for path in paths: + path.mkdir() + self.assertFalse(path.is_mount()) + + try: + with ThreadPoolExecutor() as pool: + res = list( + pool.map( + mount, + [""] * len(paths), + paths, + ["tmpfs"] * len(paths), + ) + ) + self.assertEqual(len(res), len(paths)) + + for path in paths: + with self.subTest(path=path): + self.assertTrue(path.is_mount()) + + unmounts = list(pool.map(unmount, paths)) + self.assertEqual(len(unmounts), len(paths)) + + for path in paths: + with self.subTest(path=path): + self.assertFalse(path.is_mount()) + finally: + with suppress(OSError): + for path in paths: + unmount(path) diff --git a/tests/test_integration.py b/tests/test_integration.py index 7c5db2b..91b01e6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,14 +1,25 @@ import json import unittest import urllib.request +from base64 import b64encode from multiprocessing.dummy import Pool +from textwrap import dedent from tests.gunicorn_utils import run_gunicorn -def run_code_in_snekbox(code: str) -> tuple[str, int]: - body = {"input": code} - json_data = json.dumps(body).encode("utf-8") +def b64encode_code(data: str): + data = dedent(data).strip() + return b64encode(data.encode()).decode("ascii") + + +def snekbox_run_code(code: str) -> tuple[str, int]: + body = {"args": ["-c", code]} + return snekbox_request(body) + + +def snekbox_request(content: dict) -> tuple[str, int]: + json_data = json.dumps(content).encode("utf-8") req = urllib.request.Request("http://localhost:8060/eval") req.add_header("Content-Type", "application/json; charset=utf-8") @@ -34,9 +45,71 @@ class IntegrationTests(unittest.TestCase): args = [code] * processes with Pool(processes) as p: - results = p.map(run_code_in_snekbox, args) + results = p.map(snekbox_run_code, args) responses, statuses = zip(*results) self.assertTrue(all(status == 200 for status in statuses)) self.assertTrue(all(json.loads(response)["returncode"] == 0 for response in responses)) + + def test_eval(self): + """Test normal eval requests without files.""" + with run_gunicorn(): + cases = [ + ({"input": "print('Hello')"}, "Hello\n"), + ({"args": ["-c", "print('abc12')"]}, "abc12\n"), + ] + for body, expected in cases: + with self.subTest(body=body): + response, status = snekbox_request(body) + self.assertEqual(status, 200) + self.assertEqual(json.loads(response)["stdout"], expected) + + def test_files_send_receive(self): + """Test sending and receiving files to snekbox.""" + with run_gunicorn(): + request = { + "args": ["main.py"], + "files": [ + { + "path": "main.py", + "content": b64encode_code( + """ + from pathlib import Path + from mod import lib + print(lib.var) + + with open('test.txt', 'w') as f: + f.write('test 1') + + Path('dir').mkdir() + Path('dir/test2.txt').write_text('test 2') + """ + ), + }, + {"path": "mod/__init__.py"}, + {"path": "mod/lib.py", "content": b64encode_code("var = 'hello'")}, + ], + } + + expected = { + "stdout": "hello\n", + "returncode": 0, + "files": [ + { + "path": "dir/test2.txt", + "size": len("test 2"), + "content": b64encode_code("test 2"), + }, + { + "path": "test.txt", + "size": len("test 1"), + "content": b64encode_code("test 1"), + }, + ], + } + + response, status = snekbox_request(request) + + self.assertEqual(200, status) + self.assertEqual(expected, json.loads(response)) diff --git a/tests/test_main.py b/tests/test_main.py index 77b3130..24c067c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase): @patch("sys.argv", ["", "import sys; sys.exit(22)"]) def test_main_exits_with_returncode(self): - """Should exit with the subprocess's returncode if it's non-zero.""" + """Should exit with the subprocess returncode if it's non-zero.""" with self.assertRaises(SystemExit) as cm: snekbox_main.main() diff --git a/tests/test_memfs.py b/tests/test_memfs.py new file mode 100644 index 0000000..0555726 --- /dev/null +++ b/tests/test_memfs.py @@ -0,0 +1,65 @@ +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import ExitStack +from unittest import TestCase, mock +from uuid import uuid4 + +from snekbox.memfs import MemFS + +UUID_TEST = uuid4() + + +class MemFSTests(TestCase): + def setUp(self): + super().setUp() + self.logger = logging.getLogger("snekbox.memfs") + self.logger.setLevel(logging.WARNING) + + @mock.patch("snekbox.memfs.uuid4", lambda: UUID_TEST) + def test_assignment_thread_safe(self): + """Test concurrent mounting works in multi-thread environments.""" + # Concurrently create MemFS in threads, check only 1 can be created + # Others should result in RuntimeError + with ExitStack() as stack: + with ThreadPoolExecutor() as executor: + memfs: MemFS | None = None + # Each future uses enter_context to ensure __exit__ on test exception + futures = [ + executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8) + ] + for future in futures: + # We should have exactly one result and all others RuntimeErrors + if err := future.exception(): + self.assertIsInstance(err, RuntimeError) + else: + self.assertIsNone(memfs) + memfs = future.result() + + # Original memfs should still exist afterwards + self.assertIsInstance(memfs, MemFS) + self.assertTrue(memfs.path.is_mount()) + + def test_cleanup(self): + """Test explicit cleanup.""" + memfs = MemFS(10) + path = memfs.path + self.assertTrue(path.is_mount()) + memfs.cleanup() + self.assertFalse(path.exists()) + + def test_context_cleanup(self): + """Context __exit__ should trigger cleanup.""" + with MemFS(10) as memfs: + path = memfs.path + self.assertTrue(path.is_mount()) + self.assertFalse(path.exists()) + + def test_implicit_cleanup(self): + """Test implicit _cleanup triggered by GC.""" + memfs = MemFS(10) + path = memfs.path + self.assertTrue(path.is_mount()) + # Catch the warning about implicit cleanup + with self.assertWarns(ResourceWarning): + del memfs + self.assertFalse(path.exists()) diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 6f6e2a7..cad79f3 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -9,28 +9,40 @@ from itertools import product from pathlib import Path from textwrap import dedent +from snekbox.filesystem import Size from snekbox.nsjail import NsJail +from snekbox.snekio import FileAttachment class NsJailTests(unittest.TestCase): def setUp(self): super().setUp() - self.nsjail = NsJail() + # Specify lower limits for unit tests to complete within time limits + self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB) self.logger = logging.getLogger("snekbox.nsjail") self.logger.setLevel(logging.WARNING) + def eval_code(self, code: str): + return self.nsjail.python3(["-c", code]) + + def eval_file(self, code: str, name: str = "test.py", **kwargs): + file = FileAttachment(name, code.encode()) + return self.nsjail.python3([name], [file], **kwargs) + def test_print_returns_0(self): - result = self.nsjail.python3("print('test')") - self.assertEqual(result.returncode, 0) - self.assertEqual(result.stdout, "test\n") - self.assertEqual(result.stderr, None) + for fn in (self.eval_code, self.eval_file): + with self.subTest(fn.__name__): + result = fn("print('test')") + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "test\n") + self.assertEqual(result.stderr, None) def test_timeout_returns_137(self): code = "while True: pass" with self.assertLogs(self.logger) as log: - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 137) self.assertEqual(result.stdout, "") @@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase): # Add a kilobyte just to be safe. code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}" - result = self.nsjail.python3(code) - self.assertEqual(result.returncode, 137) + result = self.eval_file(code) self.assertEqual(result.stdout, "") + self.assertEqual(result.returncode, 137) + self.assertEqual(result.stderr, None) + + def test_multi_files(self): + files = [ + FileAttachment("main.py", "import lib; print(lib.x)".encode()), + FileAttachment("lib.py", "x = 'hello'".encode()), + ] + + result = self.nsjail.python3(["main.py"], files) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "hello\n") self.assertEqual(result.stderr, None) def test_subprocess_resource_unavailable(self): + max_pids = self.nsjail.config.cgroup_pids_max code = dedent( - """ + f""" import subprocess - # Max PIDs is 5. - for _ in range(6): + # Should fail at n (max PIDs) since the caller python process counts as well + for _ in range({max_pids}): print(subprocess.Popen( [ '/usr/local/bin/python3', @@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) + # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n" + expected = "\n".join(map(str, range(2, max_pids + 1))) + self.assertIn(expected, result.stdout) self.assertEqual(result.stderr, None) def test_multiprocess_resource_limits(self): @@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase): """ ) - result = self.nsjail.python3(code) + result = self.eval_file(code) exit_codes = result.stdout.strip().split() self.assertIn("-9", exit_codes) @@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Read-only file system", result.stdout) self.assertEqual(result.stderr, None) + def test_write(self): + code = dedent( + """ + from pathlib import Path + with open('test.txt', 'w') as f: + f.write('hello') + print(Path('test.txt').read_text()) + """ + ).strip() + + result = self.eval_file(code) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "hello\n") + self.assertEqual(result.stderr, None) + + def test_write_exceed_space(self): + code = dedent( + f""" + size = {self.nsjail.memfs_instance_size} // 2048 + with open('f.bin', 'wb') as f: + for i in range(size): + f.write(b'1' * 2048) + """ + ).strip() + + result = self.eval_file(code) + self.assertEqual(result.returncode, 1) + self.assertIn("No space left on device", result.stdout) + self.assertEqual(result.stderr, None) + def test_forkbomb_resource_unavailable(self): code = dedent( """ @@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) self.assertEqual(result.stderr, None) + def test_file_parsing_timeout(self): + code = dedent( + """ + import os + data = "a" * 1024 + size = 32 * 1024 * 1024 + + with open("file", "w") as f: + for _ in range((size // 1024) - 5): + f.write(data) + + for i in range(100): + os.symlink("file", f"file{i}") + """ + ).strip() + + nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1) + result = nsjail.python3(["-c", code]) + self.assertEqual(result.returncode, None) + self.assertEqual( + result.stdout, "TimeoutError: Exceeded time limit while parsing attachments" + ) + self.assertEqual(result.stderr, None) + + def test_file_write_error(self): + """Test errors during file write.""" + result = self.nsjail.python3( + [""], + [ + FileAttachment("dir/test.txt", b"abc"), + FileAttachment("dir", b"xyz"), + ], + ) + + self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.") + self.assertEqual(result.stderr, None) + self.assertEqual(result.returncode, None) + def test_sigsegv_returns_139(self): # In honour of Juan. code = dedent( """ @@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 139) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) def test_null_byte_value_error(self): - result = self.nsjail.python3("\0") + # This error only occurs with `-c` mode + result = self.eval_code("\0") self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, "ValueError: embedded null byte") self.assertEqual(result.stderr, None) def test_print_bad_unicode_encode_error(self): - result = self.nsjail.python3("print(chr(56550))") + result = self.eval_file("print(chr(56550))") self.assertEqual(result.returncode, 1) self.assertIn("UnicodeEncodeError", result.stdout) self.assertEqual(result.stderr, None) def test_unicode_env_erase_escape_fails(self): - result = self.nsjail.python3( + result = self.eval_file( dedent( """ import os @@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No such file or directory", result.stdout) self.assertEqual(result.stderr, None) @@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Function not implemented", result.stdout) self.assertEqual(result.stderr, None) def test_numpy_import(self): - result = self.nsjail.python3("import numpy") + result = self.eval_file("import numpy") self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertLess( result.stdout.find(stdout_msg), result.stdout.find(stderr_msg), @@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase): def test_stdout_flood_results_in_graceful_sigterm(self): code = "while True: print('abcdefghij')" - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 143) def test_large_output_is_truncated(self): @@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase): self.assertEqual(output, chunk * expected_chunks) def test_nsjail_args(self): - args = ("foo", "bar") - result = self.nsjail.python3("", nsjail_args=args) + args = ["foo", "bar"] + result = self.nsjail.python3((), nsjail_args=args) end = result.args.index("--") self.assertEqual(result.args[end - len(args) : end], args) def test_py_args(self): - args = ("-m", "timeit") - result = self.nsjail.python3("", py_args=args) - - self.assertEqual(result.returncode, 0) - self.assertEqual(result.args[-3:-1], args) + cases = [ + # Normal args + (["-c", "print('hello')"], ["-c", "print('hello')"]), + # Leading empty strings should be removed + (["", "-m", "timeit"], ["-m", "timeit"]), + (["", "", "-m", "timeit"], ["-m", "timeit"]), + (["", "", "", "-m", "timeit"], ["-m", "timeit"]), + # Non-leading empty strings should be preserved + (["-m", "timeit", ""], ["-m", "timeit", ""]), + ] + + for args, expected in cases: + with self.subTest(args=args): + result = self.nsjail.python3(py_args=args) + idx = result.args.index("-BSqu") + self.assertEqual(result.args[idx + 1 :], expected) + self.assertEqual(result.returncode, 0) class NsJailArgsTests(unittest.TestCase): diff --git a/tests/test_snekio.py b/tests/test_snekio.py new file mode 100644 index 0000000..8f04429 --- /dev/null +++ b/tests/test_snekio.py @@ -0,0 +1,58 @@ +from unittest import TestCase + +from snekbox import snekio +from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError + + +class SnekIOTests(TestCase): + def test_safe_path(self) -> None: + cases = [ + ("", ""), + ("foo", "foo"), + ("foo/bar", "foo/bar"), + ("foo/bar.ext", "foo/bar.ext"), + ] + + for path, expected in cases: + self.assertEqual(snekio.safe_path(path), expected) + + def test_safe_path_raise(self): + cases = [ + ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"), + ("/foo", IllegalPathError, "File path '/foo' must be relative"), + ] + + for path, error, msg in cases: + with self.assertRaises(error) as cm: + snekio.safe_path(path) + self.assertEqual(str(cm.exception), msg) + + def test_file_from_dict(self): + cases = [ + ({"path": "foo", "content": ""}, FileAttachment("foo", b"")), + ({"path": "foo"}, FileAttachment("foo", b"")), + ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")), + ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")), + ] + + for data, expected in cases: + self.assertEqual(FileAttachment.from_dict(data), expected) + + def test_file_from_dict_error(self): + cases = [ + ( + {"path": "foo", "content": "9"}, + ParsingError, + "Invalid base64 encoding for file 'foo'", + ), + ( + {"path": "var/a.txt", "content": "1="}, + ParsingError, + "Invalid base64 encoding for file 'var/a.txt'", + ), + ] + + for data, error, msg in cases: + with self.assertRaises(error) as cm: + FileAttachment.from_dict(data) + self.assertEqual(str(cm.exception), msg) |