aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--README.md21
-rw-r--r--config/snekbox.cfg8
-rw-r--r--docker-compose.yml2
-rw-r--r--snekbox/__init__.py2
-rw-r--r--snekbox/__main__.py2
-rw-r--r--snekbox/api/resources/eval.py82
-rw-r--r--snekbox/filesystem.py88
-rw-r--r--snekbox/memfs.py185
-rw-r--r--snekbox/nsjail.py137
-rw-r--r--snekbox/process.py32
-rw-r--r--snekbox/snekio.py99
-rw-r--r--snekbox/utils/__init__.py4
-rw-r--r--snekbox/utils/timed.py37
-rw-r--r--tests/api/__init__.py4
-rw-r--r--tests/api/test_eval.py106
-rw-r--r--tests/test_filesystem.py143
-rw-r--r--tests/test_integration.py81
-rw-r--r--tests/test_main.py2
-rw-r--r--tests/test_memfs.py65
-rw-r--r--tests/test_nsjail.py170
-rw-r--r--tests/test_snekio.py58
21 files changed, 1231 insertions, 97 deletions
diff --git a/README.md b/README.md
index 7540e21..3c3642a 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,9 @@
Python sandbox runners for executing code in isolation aka snekbox.
+Supports an in-memory [virtual read/write file system](#memory-file-system) within the sandbox,
+allowing text or binary files to be sent and returned.
+
A client sends Python code to a snekbox, the snekbox executes the code, and finally the results of the execution are returned to the client.
```mermaid
@@ -60,10 +63,26 @@ The main features of the default configuration are:
* Memory limit
* Process count limit
* No networking
-* Restricted, read-only filesystem
+* Restricted, read-only system filesystem
+* Memory-based read-write filesystem mounted as working directory `/home`
NsJail is configured through [`snekbox.cfg`]. It contains the exact values for the items listed above. The configuration format is defined by a [protobuf file][7] which can be referred to for documentation. The command-line options of NsJail can also serve as documentation since they closely follow the config file format.
+### Memory File System
+
+On each execution, the host will mount an instance-specific `tmpfs` drive, which is used as a limited read-write folder for the sandboxed code. There is no access to other files or directories on the host container beyond the other read-only mounted system folders. Instance file systems are isolated; it is not possible for sandboxed code to access another instance's writeable directory.
+
+The following options for the memory file system are configurable in [gunicorn.conf.py](config/gunicorn.conf.py):
+
+* `memfs_instance_size` Size in bytes for the capacity of each instance file system.
+* `memfs_home` Path to the home directory within the instance file system.
+* `memfs_output` Path to the output directory within the instance file system.
+* `files_limit` Maximum number of valid output files to parse.
+* `files_timeout` Maximum time in seconds for output file parsing and encoding.
+* `files_pattern` Glob pattern to match files within `output`.
+
+The sandboxed code execution will start with a writeable working directory of `home`. By default, the output folder is also `home`. New files, and uploaded files with a newer last modified time, will be returned to the client on completion.
+
### Gunicorn
[Gunicorn settings] can be found in [`gunicorn.conf.py`]. In the default configuration, the worker count, the bind address, and the WSGI app URI are likely the only things of any interest. Since it uses the default synchronous workers, the [worker count] effectively determines how many concurrent code evaluations can be performed.
diff --git a/config/snekbox.cfg b/config/snekbox.cfg
index 87c216e..5dd63da 100644
--- a/config/snekbox.cfg
+++ b/config/snekbox.cfg
@@ -3,7 +3,7 @@ description: "Execute Python"
mode: ONCE
hostname: "snekbox"
-cwd: "/snekbox"
+cwd: "/home"
time_limit: 6
@@ -16,10 +16,12 @@ envar: "VECLIB_MAXIMUM_THREADS=5"
envar: "NUMEXPR_NUM_THREADS=5"
envar: "PYTHONPATH=/snekbox/user_base/lib/python3.11/site-packages"
envar: "PYTHONIOENCODING=utf-8:strict"
+envar: "HOME=home"
keep_caps: false
rlimit_as: 700
+rlimit_fsize_type: INF
clone_newnet: true
clone_newuser: true
@@ -108,12 +110,12 @@ cgroup_mem_max: 52428800
cgroup_mem_swap_max: 0
cgroup_mem_mount: "/sys/fs/cgroup/memory"
-cgroup_pids_max: 5
+cgroup_pids_max: 6
cgroup_pids_mount: "/sys/fs/cgroup/pids"
iface_no_lo: true
exec_bin {
path: "/usr/local/bin/python"
- arg: "-Squ"
+ arg: "-BSqu"
}
diff --git a/docker-compose.yml b/docker-compose.yml
index 9d3ae71..0613abc 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -8,7 +8,7 @@ services:
image: ghcr.io/python-discord/snekbox${IMAGE_SUFFIX:--venv:dev}
pull_policy: never
ports:
- - 8060:8060
+ - "8060:8060"
init: true
ipc: none
tty: true
diff --git a/snekbox/__init__.py b/snekbox/__init__.py
index bed3692..b45960b 100644
--- a/snekbox/__init__.py
+++ b/snekbox/__init__.py
@@ -12,7 +12,7 @@ from snekbox.api import SnekAPI # noqa: E402
from snekbox.nsjail import NsJail # noqa: E402
from snekbox.utils.logging import init_logger, init_sentry # noqa: E402
-__all__ = ("NsJail", "SnekAPI")
+__all__ = ("NsJail", "SnekAPI", "DEBUG")
init_sentry(__version__)
init_logger(DEBUG)
diff --git a/snekbox/__main__.py b/snekbox/__main__.py
index 2382c4c..239f5d5 100644
--- a/snekbox/__main__.py
+++ b/snekbox/__main__.py
@@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
"""Evaluate Python code through NsJail."""
args = parse_args()
- result = NsJail().python3(args.code, nsjail_args=args.nsjail_args, py_args=args.py_args)
+ result = NsJail().python3(py_args=[*args.py_args, args.code], nsjail_args=args.nsjail_args)
print(result.stdout)
if result.returncode != 0:
diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py
index 1df6c1b..9a53577 100644
--- a/snekbox/api/resources/eval.py
+++ b/snekbox/api/resources/eval.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
import falcon
@@ -7,6 +9,8 @@ from snekbox.nsjail import NsJail
__all__ = ("EvalResource",)
+from snekbox.snekio import FileAttachment, ParsingError
+
log = logging.getLogger(__name__)
@@ -25,8 +29,26 @@ class EvalResource:
"properties": {
"input": {"type": "string"},
"args": {"type": "array", "items": {"type": "string"}},
+ "files": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "path": {
+ "type": "string",
+ # Disallow starting with / or containing \0 anywhere
+ "pattern": r"^(?!/)(?!.*\\0).*$",
+ },
+ "content": {"type": "string"},
+ },
+ "required": ["path"],
+ },
+ },
},
- "required": ["input"],
+ "anyOf": [
+ {"required": ["input"]},
+ {"required": ["args"]},
+ ],
}
def __init__(self, nsjail: NsJail):
@@ -38,8 +60,11 @@ class EvalResource:
Evaluate Python code and return stdout, stderr, and the return code.
A list of arguments for the Python subprocess can be specified as `args`.
- Otherwise, the default argument "-c" is used to execute the input code.
- The input code is always passed as the last argument to Python.
+
+ If `input` is specified, it will be appended as the last argument to `args`,
+ and `args` will have a default argument of `"-c"`.
+
+ Either `input` or `args` must be specified.
The return codes mostly resemble those of a Unix shell. Some noteworthy cases:
@@ -53,15 +78,35 @@ class EvalResource:
Request body:
>>> {
- ... "input": "[i for i in range(1000)]",
- ... "args": ["-m", "timeit"] # This is optional
+ ... "input": "print('Hello')"
+ ... }
+
+ >>> {
+ ... "args": ["-c", "print('Hello')"]
+ ... }
+
+ >>> {
+ ... "args": ["main.py"],
+ ... "files": [
+ ... {
+ ... "path": "main.py",
+ ... "content": "SGVsbG8...=" # Base64
+ ... }
+ ... ]
... }
Response format:
>>> {
- ... "stdout": "10000 loops, best of 5: 23.8 usec per loop\n",
- ... "returncode": 0
+ ... "stdout": "10000 loops, best of 5: 23.8 usec per loop",
+ ... "returncode": 0,
+ ... "files": [
+ ... {
+ ... "path": "output.png",
+ ... "size": 57344,
+ ... "content": "eJzzSM3...=" # Base64
+ ... }
+ ... ]
... }
Status codes:
@@ -69,17 +114,28 @@ class EvalResource:
- 200
Successful evaluation; not indicative that the input code itself works
- 400
- Input's JSON schema is invalid
+ Input JSON schema is invalid
- 415
Unsupported content type; only application/JSON is supported
"""
- code = req.media["input"]
- args = req.media.get("args", ("-c",))
-
+ body: dict[str, str | list[str] | list[dict[str, str]]] = req.media
+ # If `input` is supplied, default `args` to `-c`
+ if "input" in body:
+ body.setdefault("args", ["-c"])
+ body["args"].append(body["input"])
try:
- result = self.nsjail.python3(code, py_args=args)
+ result = self.nsjail.python3(
+ py_args=body["args"],
+ files=[FileAttachment.from_dict(file) for file in body.get("files", [])],
+ )
+ except ParsingError as e:
+ raise falcon.HTTPBadRequest(title="Request file is invalid", description=str(e))
except Exception:
log.exception("An exception occurred while trying to process the request")
raise falcon.HTTPInternalServerError
- resp.media = {"stdout": result.stdout, "returncode": result.returncode}
+ resp.media = {
+ "stdout": result.stdout,
+ "returncode": result.returncode,
+ "files": [f.as_dict for f in result.files],
+ }
diff --git a/snekbox/filesystem.py b/snekbox/filesystem.py
new file mode 100644
index 0000000..312707c
--- /dev/null
+++ b/snekbox/filesystem.py
@@ -0,0 +1,88 @@
+"""Mounts and unmounts filesystems."""
+from __future__ import annotations
+
+import ctypes
+import os
+from ctypes.util import find_library
+from enum import IntEnum
+from pathlib import Path
+
+__all__ = ("mount", "unmount", "Size", "UnmountFlags")
+
+libc = ctypes.CDLL(find_library("c"), use_errno=True)
+libc.mount.argtypes = (
+ ctypes.c_char_p,
+ ctypes.c_char_p,
+ ctypes.c_char_p,
+ ctypes.c_ulong,
+ ctypes.c_char_p,
+)
+libc.umount2.argtypes = (ctypes.c_char_p, ctypes.c_int)
+
+
+class Size(IntEnum):
+    """Multipliers that convert binary (power-of-1024) size units to bytes."""
+
+    KiB = 1 << 10
+    MiB = 1 << 20
+    GiB = 1 << 30
+    TiB = 1 << 40
+
+
+class UnmountFlags(IntEnum):
+    """Bit flags that can be passed to the `umount2` system call."""
+
+    MNT_FORCE = 1 << 0
+    MNT_DETACH = 1 << 1
+    MNT_EXPIRE = 1 << 2
+    UMOUNT_NOFOLLOW = 1 << 3
+
+
+def mount(source: Path | str, target: Path | str, fs: str, **options: str | int) -> None:
+    """
+    Mount a filesystem.
+
+    https://man7.org/linux/man-pages/man8/mount.8.html
+    https://man7.org/linux/man-pages/man2/mount.2.html
+
+    Args:
+        source: Source directory or device.
+        target: Target directory.
+        fs: Filesystem type.
+        **options: Filesystem-specific mount options, passed as `key=value` pairs.
+
+    Raises:
+        OSError: On any mount error, or if `target` is already a mount point.
+    """
+    if Path(target).is_mount():
+        raise OSError(f"{target} is already a mount point")
+
+    # The `data` argument of mount(2) is a string of comma-separated options;
+    # any other separator only happens to work when a single option is given.
+    data = ",".join(f"{key}={value}" for key, value in options.items())
+
+    # No mount flags are set (mountflags=0); everything is passed through `data`.
+    result: int = libc.mount(
+        str(source).encode(), str(target).encode(), fs.encode(), 0, data.encode()
+    )
+    if result < 0:
+        err = ctypes.get_errno()
+        raise OSError(err, f"Error mounting {target}: {os.strerror(err)}")
+
+
+def unmount(target: Path | str, flags: UnmountFlags | int = UnmountFlags.MNT_DETACH) -> None:
+    """
+    Unmount the filesystem mounted at `target`.
+
+    https://man7.org/linux/man-pages/man2/umount.2.html
+
+    Args:
+        target: Directory that is currently a mount point.
+        flags: Flags forwarded to `umount2`.
+
+    Raises:
+        OSError: If `target` is not a mount point or the unmount call fails.
+    """
+    mount_point = Path(target)
+    if not mount_point.is_mount():
+        raise OSError(f"{target} is not a mount point")
+
+    status: int = libc.umount2(str(target).encode(), int(flags))
+    if status >= 0:
+        return
+
+    # A negative status means failure; the reason is reported through errno.
+    code = ctypes.get_errno()
+    raise OSError(code, f"Error unmounting {target}: {os.strerror(code)}")
diff --git a/snekbox/memfs.py b/snekbox/memfs.py
new file mode 100644
index 0000000..f32fed1
--- /dev/null
+++ b/snekbox/memfs.py
@@ -0,0 +1,185 @@
+"""Memory filesystem for snekbox."""
+from __future__ import annotations
+
+import logging
+import warnings
+import weakref
+from collections.abc import Generator
+from contextlib import suppress
+from pathlib import Path
+from types import TracebackType
+from typing import Type
+from uuid import uuid4
+
+from snekbox.filesystem import mount, unmount
+from snekbox.snekio import FileAttachment
+
+log = logging.getLogger(__name__)
+
+__all__ = ("MemFS",)
+
+
+class MemFS:
+    """An in-memory temporary file system backed by a per-instance `tmpfs` mount."""
+
+    def __init__(
+        self,
+        instance_size: int,
+        root_dir: str | Path = "/memfs",
+        home: str = "home",
+        output: str = "home",
+    ) -> None:
+        """
+        Initialize an in-memory temporary file system.
+
+        Examples:
+            >>> with MemFS(1024) as memfs:
+            ...     (memfs.home / "test.txt").write_text("Hello")
+
+        Args:
+            instance_size: Size limit of each tmpfs instance in bytes.
+            root_dir: Root directory to mount instances in.
+            home: Name of the home directory, relative to the instance root.
+            output: Name of the output directory, relative to the instance root.
+
+        Raises:
+            RuntimeError: If no instance could be created and mounted in 10 attempts.
+        """
+        self.instance_size = instance_size
+        self.root_dir = Path(root_dir)
+        self.root_dir.mkdir(exist_ok=True, parents=True)
+        self._home_name = home
+        self._output_name = output
+
+        last_error: OSError | None = None
+        for _ in range(10):
+            self.path = self.root_dir / str(uuid4())
+            try:
+                self.path.mkdir()
+            except OSError as e:
+                # Most likely a name collision; retry with a fresh name.
+                last_error = e
+                continue
+
+            try:
+                mount("", self.path, "tmpfs", size=self.instance_size)
+            except OSError as e:
+                last_error = e
+                # Do not leave an empty, unmounted directory behind for the failed attempt.
+                with suppress(OSError):
+                    self.path.rmdir()
+                continue
+            break
+        else:
+            # Chain the last failure so the real cause (e.g. a permission error) is visible.
+            raise RuntimeError(
+                "Failed to generate a unique MemFS name in 10 attempts"
+            ) from last_error
+
+        # Register the finalizer as soon as the mount exists, so that a failure while
+        # creating the sub-directories below cannot leak the mounted tmpfs.
+        self._finalizer = weakref.finalize(
+            self,
+            self._cleanup,
+            self.path,
+            warn_message=f"Implicitly cleaning up {self!r}",
+        )
+
+        self.mkdir(self.home)
+        self.mkdir(self.output)
+
+    @classmethod
+    def _cleanup(cls, path: Path, warn_message: str) -> None:
+        """Implicit cleanup of the MemFS; errors are suppressed and a ResourceWarning is issued."""
+        with suppress(OSError):
+            unmount(path)
+            path.rmdir()
+        warnings.warn(warn_message, ResourceWarning)
+
+    def cleanup(self) -> None:
+        """Unmount the tmpfs and remove the directory."""
+        # `detach()` is truthy only while the finalizer is still alive, i.e. on first cleanup.
+        if self._finalizer.detach() or self.path.exists():
+            unmount(self.path)
+            self.path.rmdir()
+
+    @property
+    def name(self) -> str:
+        """Name of the temp dir."""
+        return self.path.name
+
+    @property
+    def home(self) -> Path:
+        """Path to home directory."""
+        return self.path / self._home_name
+
+    @property
+    def output(self) -> Path:
+        """Path to output directory."""
+        return self.path / self._output_name
+
+    def __enter__(self) -> MemFS:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self.cleanup()
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} {self.path}>"
+
+    def mkdir(self, path: Path | str, chmod: int = 0o777) -> Path:
+        """Create a directory in the tempdir, set its mode to `chmod`, and return its path."""
+        folder = Path(self.path, path)
+        folder.mkdir(parents=True, exist_ok=True)
+        folder.chmod(chmod)
+        return folder
+
+    def files(
+        self,
+        limit: int | None,
+        pattern: str = "**/*",
+        exclude_files: dict[Path, float] | None = None,
+    ) -> Generator[FileAttachment, None, None]:
+        """
+        Yields FileAttachments for files found in the output directory.
+
+        Args:
+            limit: The maximum number of files to parse. None means no limit.
+            pattern: The glob pattern to match files against.
+            exclude_files: A dict of Paths and last modified times.
+                Files will be excluded if their last modified time
+                is equal to the provided value.
+        """
+        count = 0
+        for file in self.output.rglob(pattern):
+            if exclude_files and (orig_time := exclude_files.get(file)) is not None:
+                new_time = file.stat().st_mtime
+                log.info(f"Checking {file.name} ({orig_time=}, {new_time=})")
+                if new_time == orig_time:
+                    log.info(f"Skipping {file.name!r} as it has not been modified")
+                    continue
+
+            # `>=` so that at most `limit` files are yielded, never `limit + 1`.
+            if limit is not None and count >= limit:
+                log.info(f"Max attachments {limit} reached, skipping remaining files")
+                break
+
+            if file.is_file():
+                count += 1
+                log.info(f"Found valid file for upload {file.name!r}")
+                yield FileAttachment.from_path(file, relative_to=self.output)
+
+    def files_list(
+        self,
+        limit: int | None,
+        pattern: str,
+        exclude_files: dict[Path, float] | None = None,
+        preload_dict: bool = False,
+    ) -> list[FileAttachment]:
+        """
+        Return a sorted list of FileAttachments for files within the output directory.
+
+        Args:
+            limit: The maximum number of files to parse. None means no limit.
+            pattern: The glob pattern to match files against.
+            exclude_files: A dict of Paths and last modified times.
+                Files will be excluded if their last modified time
+                is equal to the provided value.
+            preload_dict: Whether to preload as_dict property data.
+        Returns:
+            List of FileAttachments sorted lexically by path name.
+        """
+        res = sorted(
+            self.files(limit=limit, pattern=pattern, exclude_files=exclude_files),
+            key=lambda f: f.path,
+        )
+        if preload_dict:
+            for file in res:
+                # Loads the cached property as attribute
+                _ = file.as_dict
+        return res
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index 63afdef..f014850 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -2,26 +2,43 @@ import logging
import re
import subprocess
import sys
-import textwrap
-from subprocess import CompletedProcess
+from collections.abc import Generator
+from pathlib import Path
from tempfile import NamedTemporaryFile
-from typing import Iterable
+from typing import Iterable, TypeVar
from google.protobuf import text_format
from snekbox import DEBUG, utils
from snekbox.config_pb2 import NsJailConfig
+from snekbox.filesystem import Size
+from snekbox.memfs import MemFS
+from snekbox.process import EvalResult
+from snekbox.snekio import FileAttachment
+from snekbox.utils.timed import timed
__all__ = ("NsJail",)
log = logging.getLogger(__name__)
+_T = TypeVar("_T")
+
# [level][timestamp][PID]? function_signature:line_no? message
LOG_PATTERN = re.compile(
r"\[(?P<level>(I)|[DWEF])\]\[.+?\](?(2)|(?P<func>\[\d+\] .+?:\d+ )) ?(?P<msg>.+)"
)
+def iter_lstrip(iterable: Iterable[_T]) -> Generator[_T, None, None]:
+    """Yield the items of `iterable`, skipping every falsy item before the first truthy one."""
+    started = False
+    for item in iterable:
+        # Truthiness only matters until the first truthy item has been seen.
+        if started or item:
+            started = True
+            yield item
+
+
class NsJail:
"""
Core Snekbox functionality, providing safe execution of Python code.
@@ -35,12 +52,41 @@ class NsJail:
config_path: str = "./config/snekbox.cfg",
max_output_size: int = 1_000_000,
read_chunk_size: int = 10_000,
+ memfs_instance_size: int = 48 * Size.MiB,
+ memfs_home: str = "home",
+ memfs_output: str = "home",
+ files_limit: int | None = 100,
+ files_timeout: float | None = 8,
+ files_pattern: str = "**/[!_]*",
):
+ """
+ Initialize NsJail.
+
+ Args:
+ nsjail_path: Path to the NsJail binary.
+ config_path: Path to the NsJail configuration file.
+ max_output_size: Maximum size of the output in bytes.
+ read_chunk_size: Size of the read buffer in bytes.
+ memfs_instance_size: Size of the tmpfs instance in bytes.
+ memfs_home: Name of the mounted home directory.
+ memfs_output: Name of the output directory within home,
+ can be empty to use home as output.
+ files_limit: Maximum number of output files to parse.
+ files_timeout: Maximum time in seconds to wait for output files to be read.
+ files_pattern: Pattern to match files to attach within the output directory.
+ """
self.nsjail_path = nsjail_path
self.config_path = config_path
self.max_output_size = max_output_size
self.read_chunk_size = read_chunk_size
+ self.memfs_instance_size = memfs_instance_size
+ self.memfs_home = memfs_home
+ self.memfs_output = memfs_output
+ self.files_limit = files_limit
+ self.files_timeout = files_timeout
+ self.files_pattern = files_pattern
+
self.config = self._read_config(config_path)
self.cgroup_version = utils.cgroup.init(self.config)
self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version)
@@ -129,16 +175,18 @@ class NsJail:
return "".join(output)
def python3(
- self, code: str, *, nsjail_args: Iterable[str] = (), py_args: Iterable[str] = ("-c",)
- ) -> CompletedProcess:
+ self,
+ py_args: Iterable[str],
+ files: Iterable[FileAttachment] = (),
+ nsjail_args: Iterable[str] = (),
+ ) -> EvalResult:
"""
Execute Python 3 code in an isolated environment and return the completed process.
- The `nsjail_args` passed will be used to override the values in the NsJail config.
- These arguments are only options for NsJail; they do not affect Python's arguments.
-
- `py_args` are arguments to pass to the Python subprocess before the code,
- which is the last argument. By default, it's "-c", which executes the code given.
+ Args:
+ py_args: Arguments to pass to Python.
+ files: FileAttachments to write to the sandbox prior to running Python.
+ nsjail_args: Overrides for the NsJail configuration.
"""
if self.cgroup_version == 2:
nsjail_args = ("--use_cgroupv2", *nsjail_args)
@@ -152,8 +200,19 @@ class NsJail:
*nsjail_args,
)
- with NamedTemporaryFile() as nsj_log:
- args = (
+ with NamedTemporaryFile() as nsj_log, MemFS(
+ instance_size=self.memfs_instance_size,
+ home=self.memfs_home,
+ output=self.memfs_output,
+ ) as fs:
+ nsjail_args = (
+ # Mount `home` with Read/Write access
+ "--bindmount",
+ f"{fs.home}:home",
+ *nsjail_args,
+ )
+
+ args = [
self.nsjail_path,
"--config",
self.config_path,
@@ -163,13 +222,30 @@ class NsJail:
"--",
self.config.exec_bin.path,
*self.config.exec_bin.arg,
- *py_args,
- code,
- )
+ # Filter out empty strings at start of py_args
+ # (causes issues with python cli)
+ *iter_lstrip(py_args),
+ ]
+
+ # Write provided files if any
+ files_written: dict[Path, float] = {}
+ for file in files:
+ try:
+ f_path = file.save_to(fs.home)
+ # Allow file to be writable
+ f_path.chmod(0o777)
+ # Save the written at time to later check if it was modified
+ files_written[f_path] = f_path.stat().st_mtime
+ log.info(f"Created file at {(fs.home / file.path)!r}.")
+ except OSError as e:
+ log.info(f"Failed to create file at {(fs.home / file.path)!r}.", exc_info=e)
+ return EvalResult(
+ args, None, f"{e.__class__.__name__}: Failed to create file '{file.path}'."
+ )
msg = "Executing code..."
if DEBUG:
- msg = f"{msg[:-3]}:\n{textwrap.indent(code, ' ')}\nWith the arguments {args}."
+ msg = f"{msg[:-3]} with the arguments {args}."
log.info(msg)
try:
@@ -177,23 +253,36 @@ class NsJail:
args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
except ValueError:
- return CompletedProcess(args, None, "ValueError: embedded null byte", None)
+ return EvalResult(args, None, "ValueError: embedded null byte")
try:
output = self._consume_stdout(nsjail)
except UnicodeDecodeError:
- return CompletedProcess(
- args,
- None,
- "UnicodeDecodeError: invalid Unicode in output pipe",
- None,
- )
+ return EvalResult(args, None, "UnicodeDecodeError: invalid Unicode in output pipe")
# When you send signal `N` to a subprocess to terminate it using Popen, it
# will return `-N` as its exit code. As we normally get `N + 128` back, we
# convert negative exit codes to the `N + 128` form.
returncode = -nsjail.returncode + 128 if nsjail.returncode < 0 else nsjail.returncode
+ # Parse attachments with time limit
+ try:
+ attachments = timed(
+ MemFS.files_list,
+ (fs, self.files_limit, self.files_pattern),
+ {
+ "preload_dict": True,
+ "exclude_files": files_written,
+ },
+ timeout=self.files_timeout,
+ )
+ log.info(f"Found {len(attachments)} files.")
+ except TimeoutError as e:
+ log.info(f"Exceeded time limit while parsing attachments: {e}")
+ return EvalResult(
+ args, None, "TimeoutError: Exceeded time limit while parsing attachments"
+ )
+
log_lines = nsj_log.read().decode("utf-8").splitlines()
if not log_lines and returncode == 255:
# NsJail probably failed to parse arguments so log output will still be in stdout
@@ -203,4 +292,4 @@ class NsJail:
log.info(f"nsjail return code: {returncode}")
- return CompletedProcess(args, returncode, output, None)
+ return EvalResult(args, returncode, output, files=attachments)
diff --git a/snekbox/process.py b/snekbox/process.py
new file mode 100644
index 0000000..552b91a
--- /dev/null
+++ b/snekbox/process.py
@@ -0,0 +1,32 @@
+"""Utilities for process management."""
+from collections.abc import Sequence
+from os import PathLike
+from subprocess import CompletedProcess
+from typing import TypeVar
+
+from snekbox.snekio import FileAttachment
+
+_T = TypeVar("_T")
+ArgType = (
+ str
+ | bytes
+ | PathLike[str]
+ | PathLike[bytes]
+ | Sequence[str | bytes | PathLike[str] | PathLike[bytes]]
+)
+
+
+class EvalResult(CompletedProcess[_T]):
+    """A finished evaluation job: a completed process plus any file attachments."""
+
+    def __init__(
+        self,
+        args: ArgType,
+        returncode: int | None,
+        stdout: _T | None = None,
+        stderr: _T | None = None,
+        files: list[FileAttachment] | None = None,
+    ) -> None:
+        """Store the process result and the attachments (an empty list when none are given)."""
+        super().__init__(args, returncode, stdout, stderr)
+        attachments: list[FileAttachment] = files if files else []
+        self.files = attachments
diff --git a/snekbox/snekio.py b/snekbox/snekio.py
new file mode 100644
index 0000000..821f057
--- /dev/null
+++ b/snekbox/snekio.py
@@ -0,0 +1,99 @@
+"""I/O Operations for sending / receiving files from the sandbox."""
+from __future__ import annotations
+
+from base64 import b64decode, b64encode
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+
+
+def safe_path(path: str) -> str:
+    """
+    Validate a request file path and return it unchanged.
+
+    Raises:
+        IllegalPathError: If `path` is absolute or resolves to a location outside the root.
+    """
+    candidate = Path(path)
+    if candidate.is_absolute():
+        # Absolute paths are rejected outright
+        raise IllegalPathError(f"File path '{path}' must be relative")
+
+    # Join the path onto a stand-in root and check the resolved result stays beneath it
+    root = Path("/home")
+    try:
+        resolved = root.joinpath(path).resolve()
+        resolved.relative_to(root.resolve())
+    except ValueError:
+        raise IllegalPathError(f"File path '{path}' may not traverse beyond root")
+
+    return path
+
+
+class ParsingError(ValueError):
+    """Raised when incoming content cannot be parsed."""
+
+
+class IllegalPathError(ParsingError):
+    """Raised when a request file has an illegal path (absolute, or traversing beyond root)."""
+
+
+@dataclass(frozen=True)
+class FileAttachment:
+    """A file, identified by a path, together with its raw bytes."""
+
+    # Path of the file, stored as a string
+    path: str
+    # Raw (decoded) file content
+    content: bytes
+
+    def __repr__(self) -> str:
+        # Abbreviate long paths and contents so log lines stay short
+        path = self.path
+        if len(path) > 30:
+            path = f"{path[:30]}..."
+        content = self.content
+        if len(content) > 15:
+            content = f"{content[:15]}..."
+        return f"{self.__class__.__name__}(path={path!r}, content={content!r})"
+
+    @classmethod
+    def from_dict(cls, data: dict[str, str]) -> FileAttachment:
+        """
+        Build an attachment from a dict with a `path` and optional base64 `content`.
+
+        Raises:
+            ParsingError: Raised when the dict has invalid base64 `content`.
+        """
+        path = safe_path(data["path"])
+        # A missing `content` key is treated as an empty file
+        encoded = data.get("content", "")
+        try:
+            content = b64decode(encoded)
+        except (TypeError, ValueError) as e:
+            raise ParsingError(f"Invalid base64 encoding for file '{path}'") from e
+        return cls(path, content)
+
+    @classmethod
+    def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment:
+        """
+        Read `file` from disk into an attachment.
+
+        Args:
+            file: The file to attach.
+            relative_to: If given, the attachment path is `file` relative to this root.
+        """
+        name = file if not relative_to else file.relative_to(relative_to)
+        return cls(str(name), file.read_bytes())
+
+    @property
+    def size(self) -> int:
+        """Number of bytes in the attachment content."""
+        return len(self.content)
+
+    def save_to(self, directory: Path | str) -> Path:
+        """Write the content to `self.path` under `directory` and return the written Path."""
+        destination = Path(directory, self.path)
+        # Make sure every parent directory exists before writing
+        destination.parent.mkdir(parents=True, exist_ok=True)
+        destination.write_bytes(self.content)
+        return destination
+
+    @cached_property
+    def as_dict(self) -> dict[str, str | int]:
+        """Dict form of the attachment, with the content base64-encoded."""
+        return {
+            "path": self.path,
+            "size": self.size,
+            "content": b64encode(self.content).decode("ascii"),
+        }
diff --git a/snekbox/utils/__init__.py b/snekbox/utils/__init__.py
index 6d6bc32..010fa65 100644
--- a/snekbox/utils/__init__.py
+++ b/snekbox/utils/__init__.py
@@ -1,3 +1,3 @@
-from . import cgroup, logging, swap
+from . import cgroup, logging, swap, timed
-__all__ = ("cgroup", "logging", "swap")
+__all__ = ("cgroup", "logging", "swap", "timed")
diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py
new file mode 100644
index 0000000..02388ff
--- /dev/null
+++ b/snekbox/utils/timed.py
@@ -0,0 +1,37 @@
+"""Calling functions with time limits."""
+import multiprocessing
+from collections.abc import Callable, Iterable, Mapping
+from typing import Any, TypeVar
+
+_T = TypeVar("_T")
+_V = TypeVar("_V")
+
+__all__ = ("timed",)
+
+
+def timed(
+    func: Callable[[_T], _V],
+    args: Iterable = (),
+    kwds: Mapping[str, Any] | None = None,
+    timeout: float | None = None,
+) -> _V:
+    """
+    Run `func` in a single-worker process pool and wait at most `timeout` seconds for it.
+
+    Args:
+        func: Function to call.
+        args: Positional arguments for the function.
+        kwds: Keyword arguments for the function.
+        timeout: Time limit in seconds; None waits without a limit.
+
+    Raises:
+        TimeoutError: If the function call takes longer than `timeout` seconds.
+    """
+    call_kwds = {} if kwds is None else kwds
+    with multiprocessing.Pool(1) as pool:
+        pending = pool.apply_async(func, args, call_kwds)
+        try:
+            return pending.get(timeout)
+        except multiprocessing.TimeoutError as e:
+            # Re-raise as the built-in TimeoutError so callers need not import multiprocessing.
+            raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
index 0e6e422..5f20faf 100644
--- a/tests/api/__init__.py
+++ b/tests/api/__init__.py
@@ -1,10 +1,10 @@
import logging
-from subprocess import CompletedProcess
from unittest import mock
from falcon import testing
from snekbox.api import SnekAPI
+from snekbox.process import EvalResult
class SnekAPITestCase(testing.TestCase):
@@ -13,7 +13,7 @@ class SnekAPITestCase(testing.TestCase):
self.patcher = mock.patch("snekbox.api.snekapi.NsJail", autospec=True)
self.mock_nsjail = self.patcher.start()
- self.mock_nsjail.return_value.python3.return_value = CompletedProcess(
+ self.mock_nsjail.return_value.python3.return_value = EvalResult(
args=[], returncode=0, stdout="output", stderr="error"
)
self.addCleanup(self.patcher.stop)
diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py
index 976970e..37f90e7 100644
--- a/tests/api/test_eval.py
+++ b/tests/api/test_eval.py
@@ -5,12 +5,19 @@ class TestEvalResource(SnekAPITestCase):
PATH = "/eval"
def test_post_valid_200(self):
- body = {"input": "foo"}
- result = self.simulate_post(self.PATH, json=body)
-
- self.assertEqual(result.status_code, 200)
- self.assertEqual("output", result.json["stdout"])
- self.assertEqual(0, result.json["returncode"])
+ cases = [
+ {"args": ["-c", "print('output')"]},
+ {"input": "print('hello')"},
+ {"input": "print('hello')", "args": ["-c"]},
+ {"input": "print('hello')", "args": [""]},
+ {"input": "pass", "args": ["-m", "timeit"]},
+ ]
+ for body in cases:
+ with self.subTest():
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 200)
+ self.assertEqual("output", result.json["stdout"])
+ self.assertEqual(0, result.json["returncode"])
def test_post_invalid_schema_400(self):
body = {"stuff": "foo"}
@@ -20,27 +27,100 @@ class TestEvalResource(SnekAPITestCase):
expected = {
"title": "Request data failed validation",
- "description": "'input' is a required property",
+ "description": "{'stuff': 'foo'} is not valid under any of the given schemas",
}
self.assertEqual(expected, result.json)
def test_post_invalid_data_400(self):
- bodies = ({"input": 400}, {"input": "", "args": [400]})
-
- for body in bodies:
+ bodies = ({"args": 400}, {"args": [], "files": [215]})
+ expects = ["400 is not of type 'array'", "215 is not of type 'object'"]
+ for body, expected in zip(bodies, expects):
with self.subTest():
result = self.simulate_post(self.PATH, json=body)
self.assertEqual(result.status_code, 400)
- expected = {
+ expected_json = {
"title": "Request data failed validation",
- "description": "400 is not of type 'string'",
+ "description": expected,
+ }
+ self.assertEqual(expected_json, result.json)
+
+    def test_files_path(self):
+        """Relative paths that stay inside the root should be accepted with 200."""
+        test_paths = [
+            "file.txt",
+            "./0.jpg",
+            "path/to/file",
+            "folder/../hm",  # ".." is fine as long as it does not climb above the root
+            "folder/./to/./somewhere",
+            "traversal/but/../not/beyond/../root",
+            r"backslash\\okay",  # raw string: two literal backslashes
+            r"backslash\okay",
+            "numbers/0123456789",
+        ]
+        for path in test_paths:
+            with self.subTest(path=path):
+                body = {"args": ["test.py"], "files": [{"path": path}]}
+                result = self.simulate_post(self.PATH, json=body)
+                self.assertEqual(result.status_code, 200)
+                self.assertEqual("output", result.json["stdout"])
+                self.assertEqual(0, result.json["returncode"])
+
+ def test_files_illegal_path_traversal(self):
+ """Traversal beyond root should be denied with 400 error."""
+ test_paths = [
+ "../secrets",
+ "../../dir",
+ "dir/../../secrets",
+ "dir/var/../../../file",
+ ]
+ for path in test_paths:
+ with self.subTest(path=path):
+ body = {"args": ["test.py"], "files": [{"path": path}]}
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 400)
+ expected = {
+ "title": "Request file is invalid",
+ "description": f"File path '{path}' may not traverse beyond root",
}
-
self.assertEqual(expected, result.json)
+ def test_files_illegal_path_absolute(self):
+        """Absolute paths (leading "/") should 400-error at the json schema validation stage."""
+        test_paths = [  # every entry starts with "/"
+            "/",
+            "/etc",
+            "/etc/vars/secrets",
+            "/absolute",
+            "/file.bin",
+        ]
+        for path in test_paths:
+            with self.subTest(path=path):
+                body = {"args": ["test.py"], "files": [{"path": path}]}
+                result = self.simulate_post(self.PATH, json=body)
+                self.assertEqual(result.status_code, 400)
+                self.assertEqual("Request data failed validation", result.json["title"])
+                self.assertIn("does not match", result.json["description"])
+
+    def test_files_illegal_path_null_byte(self):
+        r"""Paths containing \0 should 400-error at json schema validation stage."""
+        test_paths = [  # NOTE(review): raw strings hold a backslash + "0", not NUL; confirm
+            r"etc/passwd\0",
+            r"a\0b",
+            r"\0",
+            r"\\0",
+            r"var/\0/path",
+        ]
+        for path in test_paths:
+            with self.subTest(path=path):
+                body = {"args": ["test.py"], "files": [{"path": path}]}
+                result = self.simulate_post(self.PATH, json=body)
+                self.assertEqual(result.status_code, 400)
+                self.assertEqual("Request data failed validation", result.json["title"])
+                self.assertIn("does not match", result.json["description"])
+
def test_post_invalid_content_type_415(self):
body = "{'input': 'foo'}"
headers = {"Content-Type": "application/xml"}
diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py
new file mode 100644
index 0000000..e4d081f
--- /dev/null
+++ b/tests/test_filesystem.py
@@ -0,0 +1,143 @@
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager, suppress
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+from uuid import uuid4
+
+from snekbox.filesystem import UnmountFlags, mount, unmount
+
+
+class LibMountTests(TestCase):
+    """Tests for the mount and unmount functions of snekbox.filesystem, using tmpfs."""
+
+    temp_dir: TemporaryDirectory
+
+    @classmethod
+    def setUpClass(cls):
+        cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests")
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.temp_dir.cleanup()
+
+    @contextmanager
+    def get_mount(self):
+        """Yield a freshly mounted tmpfs path under temp_dir and unmount it afterwards."""
+        path = Path(self.temp_dir.name, str(uuid4()))
+        path.mkdir()
+        try:
+            mount(source="", target=path, fs="tmpfs")
+            yield path
+        finally:
+            with suppress(OSError):
+                unmount(path)
+
+    def test_mount(self):
+        """Test that the path is a mount inside the context and a plain folder after it."""
+        with self.get_mount() as path:
+            self.assertTrue(path.is_mount())
+            self.assertTrue(path.exists())
+        self.assertFalse(path.is_mount())
+        # Unmounting should not remove the original folder
+        self.assertTrue(path.exists())
+
+    def test_mount_errors(self):
+        """Test that invalid mount arguments raise OSError with the expected message."""
+        cases = [
+            (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"),
+            (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"),
+            (
+                dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"),
+                OSError,
+                "Invalid argument",
+            ),
+        ]
+        for case, err, msg in cases:
+            with self.subTest(case=case):
+                with self.assertRaises(err) as cm:
+                    mount(**case)
+                self.assertIn(msg, str(cm.exception))
+
+    def test_mount_duplicate(self):
+        """Test that mounting onto an already mounted path raises OSError."""
+        path = Path(self.temp_dir.name, str(uuid4()))
+        path.mkdir()
+        try:
+            mount(source="", target=path, fs="tmpfs")
+            with self.assertRaises(OSError) as cm:
+                mount(source="", target=path, fs="tmpfs")
+            self.assertIn("already a mount point", str(cm.exception))
+        finally:
+            unmount(target=path)
+
+    def test_unmount_flags(self):
+        """Test that each of these unmount flags detaches the mount."""
+        flags = [
+            UnmountFlags.MNT_FORCE,
+            UnmountFlags.MNT_DETACH,
+            UnmountFlags.UMOUNT_NOFOLLOW,
+        ]
+        for flag in flags:
+            with self.subTest(flag=flag), self.get_mount() as path:
+                self.assertTrue(path.is_mount())
+                unmount(path, flag)
+                self.assertFalse(path.is_mount())
+
+    def test_unmount_flags_expire(self):
+        """Test unmount MNT_EXPIRE behavior."""
+        with self.get_mount() as path:
+            # Assuming unmount() wraps umount2(2): the first MNT_EXPIRE call on an idle mount
+            # only marks it expired and fails with EAGAIN, i.e. BlockingIOError in Python
+            with self.assertRaises(BlockingIOError):
+                unmount(path, UnmountFlags.MNT_EXPIRE)
+
+    def test_unmount_errors(self):
+        """Test that unmounting a path which is not a mount point raises OSError."""
+        cases = [
+            (dict(target="not/exist"), OSError, "is not a mount point"),
+            (dict(target=Path("not/exist")), OSError, "is not a mount point"),
+        ]
+        for case, err, msg in cases:
+            with self.subTest(case=case):
+                with self.assertRaises(err) as cm:
+                    unmount(**case)
+                self.assertIn(msg, str(cm.exception))
+
+    def test_unmount_invalid_args(self):
+        """Test that unmount with an invalid flag value raises OSError."""
+        with self.get_mount() as path:
+            with self.assertRaises(OSError) as cm:
+                unmount(path, 251)
+            self.assertIn("Invalid argument", str(cm.exception))
+
+    def test_threading(self):
+        """Test concurrent mounting works in multi-thread environments."""
+        paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)]
+
+        for path in paths:
+            path.mkdir()
+            self.assertFalse(path.is_mount())
+
+        try:
+            with ThreadPoolExecutor() as pool:
+                # Mount every path from the pool's worker threads
+                res = list(pool.map(mount, [""] * len(paths), paths, ["tmpfs"] * len(paths)))
+                self.assertEqual(len(res), len(paths))
+
+                for path in paths:
+                    with self.subTest(path=path):
+                        self.assertTrue(path.is_mount())
+
+                unmounts = list(pool.map(unmount, paths))
+                self.assertEqual(len(unmounts), len(paths))
+
+                for path in paths:
+                    with self.subTest(path=path):
+                        self.assertFalse(path.is_mount())
+        finally:
+            # Suppress per path: already-unmounted paths raise OSError, and one failure
+            # must not stop the remaining paths from being cleaned up
+            for path in paths:
+                with suppress(OSError):
+                    unmount(path)
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 7c5db2b..91b01e6 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,14 +1,25 @@
import json
import unittest
import urllib.request
+from base64 import b64encode
from multiprocessing.dummy import Pool
+from textwrap import dedent
from tests.gunicorn_utils import run_gunicorn
-def run_code_in_snekbox(code: str) -> tuple[str, int]:
- body = {"input": code}
- json_data = json.dumps(body).encode("utf-8")
+def b64encode_code(data: str) -> str:
+    data = dedent(data).strip()  # normalise indented triple-quoted snippets first
+    return b64encode(data.encode()).decode("ascii")  # base64 text, safe to embed in JSON
+
+
+def snekbox_run_code(code: str) -> tuple[str, int]:
+    body = {"args": ["-c", code]}  # send the code inline as a "-c" argument
+    return snekbox_request(body)
+
+
+def snekbox_request(content: dict) -> tuple[str, int]:
+ json_data = json.dumps(content).encode("utf-8")
req = urllib.request.Request("http://localhost:8060/eval")
req.add_header("Content-Type", "application/json; charset=utf-8")
@@ -34,9 +45,71 @@ class IntegrationTests(unittest.TestCase):
args = [code] * processes
with Pool(processes) as p:
- results = p.map(run_code_in_snekbox, args)
+ results = p.map(snekbox_run_code, args)
responses, statuses = zip(*results)
self.assertTrue(all(status == 200 for status in statuses))
self.assertTrue(all(json.loads(response)["returncode"] == 0 for response in responses))
+
+    def test_eval(self):
+        """Test plain eval requests (no files) through both the "input" and "args" fields."""
+        with run_gunicorn():
+            cases = [  # (request body, expected stdout)
+                ({"input": "print('Hello')"}, "Hello\n"),
+                ({"args": ["-c", "print('abc12')"]}, "abc12\n"),
+            ]
+            for body, expected in cases:
+                with self.subTest(body=body):
+                    response, status = snekbox_request(body)
+                    self.assertEqual(status, 200)
+                    self.assertEqual(json.loads(response)["stdout"], expected)
+
+    def test_files_send_receive(self):
+        """Test sending files to snekbox and receiving the files the code wrote."""
+        with run_gunicorn():
+            request = {
+                "args": ["main.py"],  # run the uploaded main.py
+                "files": [
+                    {
+                        "path": "main.py",
+                        "content": b64encode_code(
+                            """
+                            from pathlib import Path
+                            from mod import lib
+                            print(lib.var)
+
+                            with open('test.txt', 'w') as f:
+                                f.write('test 1')
+
+                            Path('dir').mkdir()
+                            Path('dir/test2.txt').write_text('test 2')
+                            """
+                        ),
+                    },
+                    {"path": "mod/__init__.py"},  # no "content" key
+                    {"path": "mod/lib.py", "content": b64encode_code("var = 'hello'")},
+                ],
+            }
+
+            expected = {
+                "stdout": "hello\n",
+                "returncode": 0,
+                "files": [  # files written by the script; uploads are not echoed back
+                    {
+                        "path": "dir/test2.txt",
+                        "size": len("test 2"),
+                        "content": b64encode_code("test 2"),
+                    },
+                    {
+                        "path": "test.txt",
+                        "size": len("test 1"),
+                        "content": b64encode_code("test 1"),
+                    },
+                ],
+            }
+
+            response, status = snekbox_request(request)
+
+            self.assertEqual(200, status)
+            self.assertEqual(expected, json.loads(response))
diff --git a/tests/test_main.py b/tests/test_main.py
index 77b3130..24c067c 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase):
@patch("sys.argv", ["", "import sys; sys.exit(22)"])
def test_main_exits_with_returncode(self):
- """Should exit with the subprocess's returncode if it's non-zero."""
+ """Should exit with the subprocess returncode if it's non-zero."""
with self.assertRaises(SystemExit) as cm:
snekbox_main.main()
diff --git a/tests/test_memfs.py b/tests/test_memfs.py
new file mode 100644
index 0000000..0555726
--- /dev/null
+++ b/tests/test_memfs.py
@@ -0,0 +1,65 @@
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import ExitStack
+from unittest import TestCase, mock
+from uuid import uuid4
+
+from snekbox.memfs import MemFS
+
+UUID_TEST = uuid4()
+
+
+class MemFSTests(TestCase):
+    def setUp(self):
+        super().setUp()
+        self.logger = logging.getLogger("snekbox.memfs")
+        self.logger.setLevel(logging.WARNING)
+
+    @mock.patch("snekbox.memfs.uuid4", lambda: UUID_TEST)
+    def test_assignment_thread_safe(self):
+        """Test that concurrent MemFS creation with one shared UUID succeeds only once."""
+        # Concurrently create MemFS in threads, check only 1 can be created
+        # Others should result in RuntimeError
+        with ExitStack() as stack:
+            with ThreadPoolExecutor() as executor:
+                memfs: MemFS | None = None
+                # Each future uses enter_context to ensure __exit__ on test exception
+                futures = [
+                    executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8)
+                ]
+                for future in futures:
+                    # We should have exactly one result and all others RuntimeErrors
+                    if err := future.exception():
+                        self.assertIsInstance(err, RuntimeError)
+                    else:
+                        self.assertIsNone(memfs)
+                        memfs = future.result()
+
+                # Original memfs should still exist afterwards
+                self.assertIsInstance(memfs, MemFS)
+                self.assertTrue(memfs.path.is_mount())
+
+    def test_cleanup(self):
+        """Test that an explicit cleanup() call removes the mount path."""
+        memfs = MemFS(10)
+        path = memfs.path
+        self.assertTrue(path.is_mount())
+        memfs.cleanup()
+        self.assertFalse(path.exists())
+
+    def test_context_cleanup(self):
+        """Context __exit__ should trigger cleanup."""
+        with MemFS(10) as memfs:
+            path = memfs.path
+            self.assertTrue(path.is_mount())
+        self.assertFalse(path.exists())
+
+    def test_implicit_cleanup(self):
+        """Test implicit cleanup when the last reference to a MemFS is deleted."""
+        memfs = MemFS(10)
+        path = memfs.path
+        self.assertTrue(path.is_mount())
+        # Dropping the only reference should clean up and emit a ResourceWarning
+        with self.assertWarns(ResourceWarning):
+            del memfs
+        self.assertFalse(path.exists())
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index 6f6e2a7..cad79f3 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -9,28 +9,40 @@ from itertools import product
from pathlib import Path
from textwrap import dedent
+from snekbox.filesystem import Size
from snekbox.nsjail import NsJail
+from snekbox.snekio import FileAttachment
class NsJailTests(unittest.TestCase):
def setUp(self):
super().setUp()
- self.nsjail = NsJail()
+ # Specify lower limits for unit tests to complete within time limits
+ self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB)
self.logger = logging.getLogger("snekbox.nsjail")
self.logger.setLevel(logging.WARNING)
+ def eval_code(self, code: str):
+        return self.nsjail.python3(["-c", code])  # code is passed inline via "-c"
+
+    def eval_file(self, code: str, name: str = "test.py", **kwargs):
+        file = FileAttachment(name, code.encode())  # code is shipped as an attached file
+        return self.nsjail.python3([name], [file], **kwargs)  # and run by its file name
+
def test_print_returns_0(self):
- result = self.nsjail.python3("print('test')")
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.stdout, "test\n")
- self.assertEqual(result.stderr, None)
+ for fn in (self.eval_code, self.eval_file):
+ with self.subTest(fn.__name__):
+ result = fn("print('test')")
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "test\n")
+ self.assertEqual(result.stderr, None)
def test_timeout_returns_137(self):
code = "while True: pass"
with self.assertLogs(self.logger) as log:
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 137)
self.assertEqual(result.stdout, "")
@@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase):
# Add a kilobyte just to be safe.
code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}"
- result = self.nsjail.python3(code)
- self.assertEqual(result.returncode, 137)
+ result = self.eval_file(code)
self.assertEqual(result.stdout, "")
+ self.assertEqual(result.returncode, 137)
+ self.assertEqual(result.stderr, None)
+
+ def test_multi_files(self):
+ files = [
+ FileAttachment("main.py", "import lib; print(lib.x)".encode()),
+ FileAttachment("lib.py", "x = 'hello'".encode()),
+ ]
+
+ result = self.nsjail.python3(["main.py"], files)
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "hello\n")
self.assertEqual(result.stderr, None)
def test_subprocess_resource_unavailable(self):
+ max_pids = self.nsjail.config.cgroup_pids_max
code = dedent(
- """
+ f"""
import subprocess
- # Max PIDs is 5.
- for _ in range(6):
+ # Should fail at n (max PIDs) since the caller python process counts as well
+ for _ in range({max_pids}):
print(subprocess.Popen(
[
'/usr/local/bin/python3',
@@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
+ # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n"
+ expected = "\n".join(map(str, range(2, max_pids + 1)))
+ self.assertIn(expected, result.stdout)
self.assertEqual(result.stderr, None)
def test_multiprocess_resource_limits(self):
@@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase):
"""
)
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
exit_codes = result.stdout.strip().split()
self.assertIn("-9", exit_codes)
@@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Read-only file system", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_write(self):
+        code = dedent(  # write a file via a relative path, then read it back
+            """
+            from pathlib import Path
+            with open('test.txt', 'w') as f:
+                f.write('hello')
+            print(Path('test.txt').read_text())
+            """
+        ).strip()
+
+        result = self.eval_file(code)
+        self.assertEqual(result.returncode, 0)
+        self.assertEqual(result.stdout, "hello\n")
+        self.assertEqual(result.stderr, None)
+
+    def test_write_exceed_space(self):
+        code = dedent(  # write about memfs_instance_size bytes, in 2048-byte chunks
+            f"""
+            size = {self.nsjail.memfs_instance_size} // 2048
+            with open('f.bin', 'wb') as f:
+                for i in range(size):
+                    f.write(b'1' * 2048)
+            """
+        ).strip()
+
+        result = self.eval_file(code)
+        self.assertEqual(result.returncode, 1)
+        self.assertIn("No space left on device", result.stdout)
+        self.assertEqual(result.stderr, None)
+
def test_forkbomb_resource_unavailable(self):
code = dedent(
"""
@@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_file_parsing_timeout(self):
+        code = dedent(  # write one file of almost 32 MiB, then add 100 symlinks to it
+            """
+            import os
+            data = "a" * 1024
+            size = 32 * 1024 * 1024
+
+            with open("file", "w") as f:
+                for _ in range((size // 1024) - 5):
+                    f.write(data)
+
+            for i in range(100):
+                os.symlink("file", f"file{i}")
+            """
+        ).strip()
+
+        nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1)  # short timeout
+        result = nsjail.python3(["-c", code])
+        self.assertEqual(result.returncode, None)
+        self.assertEqual(
+            result.stdout, "TimeoutError: Exceeded time limit while parsing attachments"
+        )
+        self.assertEqual(result.stderr, None)
+
+    def test_file_write_error(self):
+        """Test that attaching a file whose path is already a directory is reported."""
+        result = self.nsjail.python3(
+            [""],
+            [
+                FileAttachment("dir/test.txt", b"abc"),
+                FileAttachment("dir", b"xyz"),  # clashes with the "dir" folder used above
+            ],
+        )
+
+        self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.")
+        self.assertEqual(result.stderr, None)
+        self.assertEqual(result.returncode, None)
+
def test_sigsegv_returns_139(self): # In honour of Juan.
code = dedent(
"""
@@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 139)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
def test_null_byte_value_error(self):
- result = self.nsjail.python3("\0")
+ # This error only occurs with `-c` mode
+ result = self.eval_code("\0")
self.assertEqual(result.returncode, None)
self.assertEqual(result.stdout, "ValueError: embedded null byte")
self.assertEqual(result.stderr, None)
def test_print_bad_unicode_encode_error(self):
- result = self.nsjail.python3("print(chr(56550))")
+ result = self.eval_file("print(chr(56550))")
self.assertEqual(result.returncode, 1)
self.assertIn("UnicodeEncodeError", result.stdout)
self.assertEqual(result.stderr, None)
def test_unicode_env_erase_escape_fails(self):
- result = self.nsjail.python3(
+ result = self.eval_file(
dedent(
"""
import os
@@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No such file or directory", result.stdout)
self.assertEqual(result.stderr, None)
@@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Function not implemented", result.stdout)
self.assertEqual(result.stderr, None)
def test_numpy_import(self):
- result = self.nsjail.python3("import numpy")
+ result = self.eval_file("import numpy")
self.assertEqual(result.returncode, 0)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
@@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertLess(
result.stdout.find(stdout_msg),
result.stdout.find(stderr_msg),
@@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase):
def test_stdout_flood_results_in_graceful_sigterm(self):
code = "while True: print('abcdefghij')"
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 143)
def test_large_output_is_truncated(self):
@@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase):
self.assertEqual(output, chunk * expected_chunks)
def test_nsjail_args(self):
- args = ("foo", "bar")
- result = self.nsjail.python3("", nsjail_args=args)
+ args = ["foo", "bar"]
+ result = self.nsjail.python3((), nsjail_args=args)
end = result.args.index("--")
self.assertEqual(result.args[end - len(args) : end], args)
def test_py_args(self):
- args = ("-m", "timeit")
- result = self.nsjail.python3("", py_args=args)
-
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.args[-3:-1], args)
+ cases = [
+ # Normal args
+ (["-c", "print('hello')"], ["-c", "print('hello')"]),
+ # Leading empty strings should be removed
+ (["", "-m", "timeit"], ["-m", "timeit"]),
+ (["", "", "-m", "timeit"], ["-m", "timeit"]),
+ (["", "", "", "-m", "timeit"], ["-m", "timeit"]),
+ # Non-leading empty strings should be preserved
+ (["-m", "timeit", ""], ["-m", "timeit", ""]),
+ ]
+
+ for args, expected in cases:
+ with self.subTest(args=args):
+ result = self.nsjail.python3(py_args=args)
+ idx = result.args.index("-BSqu")
+ self.assertEqual(result.args[idx + 1 :], expected)
+ self.assertEqual(result.returncode, 0)
class NsJailArgsTests(unittest.TestCase):
diff --git a/tests/test_snekio.py b/tests/test_snekio.py
new file mode 100644
index 0000000..8f04429
--- /dev/null
+++ b/tests/test_snekio.py
@@ -0,0 +1,58 @@
+from unittest import TestCase
+
+from snekbox import snekio
+from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError
+
+
+class SnekIOTests(TestCase):
+    def test_safe_path(self) -> None:
+        cases = [  # (input path, expected safe_path result)
+            ("", ""),
+            ("foo", "foo"),
+            ("foo/bar", "foo/bar"),
+            ("foo/bar.ext", "foo/bar.ext"),
+        ]
+
+        for path, expected in cases:
+            self.assertEqual(snekio.safe_path(path), expected)
+
+    def test_safe_path_raise(self) -> None:
+        cases = [  # (input path, expected exception type, expected exception message)
+            ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"),
+            ("/foo", IllegalPathError, "File path '/foo' must be relative"),
+        ]
+
+        for path, error, msg in cases:
+            with self.assertRaises(error) as cm:
+                snekio.safe_path(path)
+            self.assertEqual(str(cm.exception), msg)
+
+    def test_file_from_dict(self) -> None:
+        cases = [  # (request dict, expected FileAttachment); "content" is base64 text
+            ({"path": "foo", "content": ""}, FileAttachment("foo", b"")),
+            ({"path": "foo"}, FileAttachment("foo", b"")),
+            ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")),
+            ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")),
+        ]
+
+        for data, expected in cases:
+            self.assertEqual(FileAttachment.from_dict(data), expected)
+
+    def test_file_from_dict_error(self) -> None:
+        cases = [  # (request dict, expected exception type, expected exception message)
+            (
+                {"path": "foo", "content": "9"},
+                ParsingError,
+                "Invalid base64 encoding for file 'foo'",
+            ),
+            (
+                {"path": "var/a.txt", "content": "1="},
+                ParsingError,
+                "Invalid base64 encoding for file 'var/a.txt'",
+            ),
+        ]
+
+        for data, error, msg in cases:
+            with self.assertRaises(error) as cm:
+                FileAttachment.from_dict(data)
+            self.assertEqual(str(cm.exception), msg)