diff options
-rw-r--r-- | snekbox/__main__.py | 4 | ||||
-rw-r--r-- | snekbox/api/resources/eval.py | 42 | ||||
-rw-r--r-- | snekbox/nsjail.py | 34 | ||||
-rw-r--r-- | snekbox/snekio.py | 44 | ||||
-rw-r--r-- | tests/api/test_eval.py | 17 | ||||
-rw-r--r-- | tests/test_integration.py | 2 | ||||
-rw-r--r-- | tests/test_main.py | 6 | ||||
-rw-r--r-- | tests/test_nsjail.py | 60 |
8 files changed, 127 insertions, 82 deletions
diff --git a/snekbox/__main__.py b/snekbox/__main__.py index 6cbd1ea..239f5d5 100644 --- a/snekbox/__main__.py +++ b/snekbox/__main__.py @@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace: "nsjail_args", nargs="?", default=[], help="override configured NsJail options" ) parser.add_argument( - "py_args", nargs="?", default=[], help="arguments to pass to the Python process" + "py_args", nargs="?", default=["-c"], help="arguments to pass to the Python process" ) # nsjail_args and py_args are just dummies for documentation purposes. @@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace: def main() -> None: """Evaluate Python code through NsJail.""" args = parse_args() - result = NsJail().python3(args.code, nsjail_args=args.nsjail_args, py_args=args.py_args) + result = NsJail().python3(py_args=[*args.py_args, args.code], nsjail_args=args.nsjail_args) print(result.stdout) if result.returncode != 0: diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py index 38781ba..9244d7e 100644 --- a/snekbox/api/resources/eval.py +++ b/snekbox/api/resources/eval.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import falcon @@ -7,6 +9,8 @@ from snekbox.nsjail import NsJail __all__ = ("EvalResource",) +from snekbox.snekio import EvalRequestFile, FileParsingError + log = logging.getLogger(__name__) @@ -23,10 +27,17 @@ class EvalResource: REQ_SCHEMA = { "type": "object", "properties": { - "input": {"type": "string"}, "args": {"type": "array", "items": {"type": "string"}}, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": {"name": {"type": "string"}, "content": {"type": "string"}}, + "required": ["name"], + }, + }, }, - "required": ["input"], + "required": ["args"], } def __init__(self, nsjail: NsJail): @@ -51,14 +62,23 @@ class EvalResource: Request body: >>> { - ... "input": "[i for i in range(1000)]", - ... "args": ["-m", "timeit"] # This is optional + ... "args": ["-c", "print('Hello')"] + ... } + + >>> { + ... "args": ["main.py"], + ... "files": [ + ... { + ... "name": "main.py", + ... "content": "print(1)" + ... } + ... ] ... } Response format: >>> { - ... "stdout": "10000 loops, best of 5: 23.8 usec per loop\n", + ... "stdout": "10000 loops, best of 5: 23.8 usec per loop", ... "returncode": 0, ... "attachments": [ ... { @@ -76,15 +96,17 @@ class EvalResource: - 200 Successful evaluation; not indicative that the input code itself works - 400 - Input's JSON schema is invalid + Input JSON schema is invalid - 415 Unsupported content type; only application/JSON is supported """ - code = req.media["input"] - args = req.media.get("args", ("",)) - try: - result = self.nsjail.python3(code, py_args=args) + result = self.nsjail.python3( + py_args=req.media["args"], + files=[EvalRequestFile.from_dict(file) for file in req.media.get("files", [])], + ) + except FileParsingError as e: + raise falcon.HTTPBadRequest("Invalid file in request", str(e)) except Exception: log.exception("An exception occurred while trying to process the request") raise falcon.HTTPInternalServerError diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 8cd32c7..723bd30 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -2,7 +2,6 @@ import logging import re import subprocess import sys -import textwrap from tempfile import NamedTemporaryFile from typing import Iterable @@ -15,7 +14,7 @@ from snekbox.memfs import MemFS __all__ = ("NsJail",) from snekbox.process import EvalResult -from snekbox.snekio import AttachmentError +from snekbox.snekio import AttachmentError, EvalRequestFile log = logging.getLogger(__name__) @@ -139,11 +138,10 @@ class NsJail: def python3( self, - code: str, + py_args: Iterable[str], + files: Iterable[EvalRequestFile] = (), *, nsjail_args: Iterable[str] = (), - py_args: Iterable[str] = (), - use_file: bool | None = None, ) -> EvalResult: """ Execute Python 3 code in an isolated environment and return the completed process. @@ -195,30 +193,18 @@ class NsJail: "--", self.config.exec_bin.path, *self.config.exec_bin.arg, + # Filter out empty strings (causes issues with python cli) + *(arg for arg in py_args if arg), ] - # Filter out empty strings (causes issues with python cli) - args.extend(s for s in py_args if s) - - c_arg = "c" in "".join(py_args) - - # Override for `timeit` - if "timeit" in py_args: - use_file = False - - match (use_file, c_arg): - case (True, _) | (None, False): - args.append("main.py") - # Write the code to a file - code_path = fs.home / "main.py" - code_path.write_text(code) - log.info(f"Created code file at [{code_path!r}].") - case _: - args.append(code) + # Write files if any + for file in files: + file.save_to(fs.home) + log.info(f"Created file at [{(fs.home / file.name)!r}].") msg = "Executing code..." if DEBUG: - msg = f"{msg[:-3]}:\n{textwrap.indent(code, ' ')}\nWith the arguments {args}." + msg = f"{msg[:-3]}: With the arguments {args}." log.info(msg) try: diff --git a/snekbox/snekio.py b/snekbox/snekio.py index 3074041..26acd04 100644 --- a/snekbox/snekio.py +++ b/snekbox/snekio.py @@ -1,11 +1,12 @@ from __future__ import annotations -import mimetypes import zlib from base64 import b64encode from dataclasses import dataclass from pathlib import Path +RequestType = dict[str, str | bool | list[str | dict[str, str]]] + def sizeof_fmt(num: int, suffix: str = "B") -> str: """Return a human-readable file size.""" @@ -20,6 +21,40 @@ class AttachmentError(ValueError): """Raised when an attachment is invalid.""" +class FileParsingError(ValueError): + """Raised when a request file cannot be parsed.""" + + +@dataclass +class EvalRequestFile: + """A file sent in an eval request.""" + + name: str + content: str + + @classmethod + def from_dict(cls, data: dict[str, str]) -> EvalRequestFile: + """Convert a dict to a str attachment.""" + name = data["name"] + path = Path(name) + parts = path.parts + + if path.is_absolute() or set(parts[0]) & {"\\", "/"}: + raise FileParsingError(f"File path '{name}' must be relative") + + if any(set(part) == {"."} for part in parts): + raise FileParsingError(f"File path '{name}' may not use traversal ('..')") + + return cls(name, data.get("content", "")) + + def save_to(self, directory: Path) -> None: + """Save the attachment to a path directory.""" + file = Path(directory, self.name) + # Create directories if they don't exist + file.parent.mkdir(parents=True, exist_ok=True) + file.write_text(self.content) + + @dataclass class FileAttachment: """A file attachment.""" @@ -39,11 +74,6 @@ class FileAttachment: return cls(file.name, file.read_bytes()) @property - def mime(self) -> str: - """MIME type of the attachment.""" - return mimetypes.guess_type(self.name)[0] - - @property def size(self) -> int: """Size of the attachment.""" return len(self.content) @@ -54,8 +84,6 @@ class FileAttachment: content = b64encode(cmp).decode("ascii") return { "name": self.name, - "mime": self.mime, "size": self.size, - "compression": "zlib", "content": content, } diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py index 976970e..caa848e 100644 --- a/tests/api/test_eval.py +++ b/tests/api/test_eval.py @@ -5,7 +5,7 @@ class TestEvalResource(SnekAPITestCase): PATH = "/eval" def test_post_valid_200(self): - body = {"input": "foo"} + body = {"args": ["-c", "print('output')"]} result = self.simulate_post(self.PATH, json=body) self.assertEqual(result.status_code, 200) @@ -20,26 +20,25 @@ class TestEvalResource(SnekAPITestCase): expected = { "title": "Request data failed validation", - "description": "'input' is a required property", + "description": "'args' is a required property", } self.assertEqual(expected, result.json) def test_post_invalid_data_400(self): - bodies = ({"input": 400}, {"input": "", "args": [400]}) - - for body in bodies: + bodies = ({"args": 400}, {"args": [], "files": [215]}) + expects = ["400 is not of type 'array'", "215 is not of type 'object'"] + for body, expected in zip(bodies, expects): with self.subTest(): result = self.simulate_post(self.PATH, json=body) self.assertEqual(result.status_code, 400) - expected = { + expected_json = { "title": "Request data failed validation", - "description": "400 is not of type 'string'", + "description": expected, } - - self.assertEqual(expected, result.json) + self.assertEqual(expected_json, result.json) def test_post_invalid_content_type_415(self): body = "{'input': 'foo'}" diff --git a/tests/test_integration.py b/tests/test_integration.py index 7c5db2b..eba5e60 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,7 +7,7 @@ from tests.gunicorn_utils import run_gunicorn def run_code_in_snekbox(code: str) -> tuple[str, int]: - body = {"input": code} + body = {"args": ["-c", code]} json_data = json.dumps(body).encode("utf-8") req = urllib.request.Request("http://localhost:8060/eval") diff --git a/tests/test_main.py b/tests/test_main.py index 1e6cbc5..24c067c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -12,10 +12,10 @@ import snekbox.__main__ as snekbox_main class ArgParseTests(unittest.TestCase): def test_parse_args(self): subtests = ( - (["", "code"], Namespace(code="code", nsjail_args=[], py_args=[])), + (["", "code"], Namespace(code="code", nsjail_args=[], py_args=["-c"])), ( ["", "code", "--time_limit", "0"], - Namespace(code="code", nsjail_args=["--time_limit", "0"], py_args=[]), + Namespace(code="code", nsjail_args=["--time_limit", "0"], py_args=["-c"]), ), ( ["", "code", "---", "-m", "timeit"], @@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase): @patch("sys.argv", ["", "import sys; sys.exit(22)"]) def test_main_exits_with_returncode(self): - """Should exit with the subprocess's returncode if it's non-zero.""" + """Should exit with the subprocess returncode if it's non-zero.""" with self.assertRaises(SystemExit) as cm: snekbox_main.main() diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index cea96bd..324b88a 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -10,6 +10,7 @@ from pathlib import Path from textwrap import dedent from snekbox.nsjail import NsJail +from snekbox.snekio import EvalRequestFile class NsJailTests(unittest.TestCase): @@ -20,17 +21,26 @@ class NsJailTests(unittest.TestCase): self.logger = logging.getLogger("snekbox.nsjail") self.logger.setLevel(logging.WARNING) + def eval_code(self, code: str): + return self.nsjail.python3(["-c", code]) + + def eval_file(self, code: str, name: str = "test.py"): + file = EvalRequestFile(name, code) + return self.nsjail.python3([name], [file]) + def test_print_returns_0(self): - result = self.nsjail.python3("print('test')") - self.assertEqual(result.returncode, 0) - self.assertEqual(result.stdout, "test\n") - self.assertEqual(result.stderr, None) + for fn in (self.eval_code, self.eval_file): + with self.subTest(fn.__name__): + result = fn("print('test')") + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "test\n") + self.assertEqual(result.stderr, None) def test_timeout_returns_137(self): code = "while True: pass" with self.assertLogs(self.logger) as log: - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 137) self.assertEqual(result.stdout, "") @@ -41,7 +51,7 @@ class NsJailTests(unittest.TestCase): # Add a kilobyte just to be safe. code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}" - result = self.nsjail.python3(code, py_args=("-c",)) + result = self.eval_file(code) self.assertEqual(result.stdout, "") self.assertEqual(result.returncode, 137) self.assertEqual(result.stderr, None) @@ -64,7 +74,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) # Also expect n-1 processes to be opened @@ -96,7 +106,7 @@ class NsJailTests(unittest.TestCase): """ ) - result = self.nsjail.python3(code) + result = self.eval_file(code) exit_codes = result.stdout.strip().split() self.assertIn("-9", exit_codes) @@ -112,7 +122,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Read-only file system", result.stdout) self.assertEqual(result.stderr, None) @@ -127,7 +137,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "hello\n") self.assertEqual(result.stderr, None) @@ -145,7 +155,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No space left on device", result.stdout) self.assertEqual(result.stderr, None) @@ -165,7 +175,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 0) self.assertEqual( result.stdout, @@ -184,7 +194,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) self.assertEqual(result.stderr, None) @@ -197,7 +207,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 139) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -205,19 +215,19 @@ class NsJailTests(unittest.TestCase): def test_null_byte_value_error(self): # This error does not occur without -c, where it # would be a normal SyntaxError. - result = self.nsjail.python3("\0", py_args=("-c",)) + result = self.nsjail.python3(["-c", "\0"]) self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, "ValueError: embedded null byte") self.assertEqual(result.stderr, None) def test_print_bad_unicode_encode_error(self): - result = self.nsjail.python3("print(chr(56550))") + result = self.eval_file("print(chr(56550))") self.assertEqual(result.returncode, 1) self.assertIn("UnicodeEncodeError", result.stdout) self.assertEqual(result.stderr, None) def test_unicode_env_erase_escape_fails(self): - result = self.nsjail.python3( + result = self.eval_file( dedent( """ import os @@ -271,7 +281,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No such file or directory", result.stdout) self.assertEqual(result.stderr, None) @@ -287,13 +297,13 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Function not implemented", result.stdout) self.assertEqual(result.stderr, None) def test_numpy_import(self): - result = self.nsjail.python3("import numpy") + result = self.eval_file("import numpy") self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -308,7 +318,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertLess( result.stdout.find(stdout_msg), result.stdout.find(stderr_msg), @@ -319,7 +329,7 @@ class NsJailTests(unittest.TestCase): def test_stdout_flood_results_in_graceful_sigterm(self): code = "while True: print('abcdefghij')" - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 143) def test_large_output_is_truncated(self): @@ -337,17 +347,17 @@ class NsJailTests(unittest.TestCase): def test_nsjail_args(self): args = ["foo", "bar"] - result = self.nsjail.python3("", nsjail_args=args) + result = self.nsjail.python3((), nsjail_args=args) end = result.args.index("--") self.assertEqual(result.args[end - len(args) : end], args) def test_py_args(self): args = ["-m", "timeit"] - result = self.nsjail.python3("", py_args=args) + result = self.nsjail.python3(args) self.assertEqual(result.returncode, 0) - self.assertEqual(result.args[-3:-1], args) + self.assertEqual(result.args[-2:], args) class NsJailArgsTests(unittest.TestCase): |