aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--snekbox/__main__.py4
-rw-r--r--snekbox/api/resources/eval.py42
-rw-r--r--snekbox/nsjail.py34
-rw-r--r--snekbox/snekio.py44
-rw-r--r--tests/api/test_eval.py17
-rw-r--r--tests/test_integration.py2
-rw-r--r--tests/test_main.py6
-rw-r--r--tests/test_nsjail.py60
8 files changed, 127 insertions, 82 deletions
diff --git a/snekbox/__main__.py b/snekbox/__main__.py
index 6cbd1ea..239f5d5 100644
--- a/snekbox/__main__.py
+++ b/snekbox/__main__.py
@@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace:
"nsjail_args", nargs="?", default=[], help="override configured NsJail options"
)
parser.add_argument(
- "py_args", nargs="?", default=[], help="arguments to pass to the Python process"
+ "py_args", nargs="?", default=["-c"], help="arguments to pass to the Python process"
)
# nsjail_args and py_args are just dummies for documentation purposes.
@@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
"""Evaluate Python code through NsJail."""
args = parse_args()
- result = NsJail().python3(args.code, nsjail_args=args.nsjail_args, py_args=args.py_args)
+ result = NsJail().python3(py_args=[*args.py_args, args.code], nsjail_args=args.nsjail_args)
print(result.stdout)
if result.returncode != 0:
diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py
index 38781ba..9244d7e 100644
--- a/snekbox/api/resources/eval.py
+++ b/snekbox/api/resources/eval.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
import falcon
@@ -7,6 +9,8 @@ from snekbox.nsjail import NsJail
__all__ = ("EvalResource",)
+from snekbox.snekio import EvalRequestFile, FileParsingError
+
log = logging.getLogger(__name__)
@@ -23,10 +27,17 @@ class EvalResource:
REQ_SCHEMA = {
"type": "object",
"properties": {
- "input": {"type": "string"},
"args": {"type": "array", "items": {"type": "string"}},
+ "files": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {"name": {"type": "string"}, "content": {"type": "string"}},
+ "required": ["name"],
+ },
+ },
},
- "required": ["input"],
+ "required": ["args"],
}
def __init__(self, nsjail: NsJail):
@@ -51,14 +62,23 @@ class EvalResource:
Request body:
>>> {
- ... "input": "[i for i in range(1000)]",
- ... "args": ["-m", "timeit"] # This is optional
+ ... "args": ["-c", "print('Hello')"]
+ ... }
+
+ >>> {
+ ... "args": ["main.py"],
+ ... "files": [
+ ... {
+ ... "name": "main.py",
+ ... "content": "print(1)"
+ ... }
+ ... ]
... }
Response format:
>>> {
- ... "stdout": "10000 loops, best of 5: 23.8 usec per loop\n",
+ ... "stdout": "10000 loops, best of 5: 23.8 usec per loop",
... "returncode": 0,
... "attachments": [
... {
@@ -76,15 +96,17 @@ class EvalResource:
- 200
Successful evaluation; not indicative that the input code itself works
- 400
- Input's JSON schema is invalid
+ Input JSON schema is invalid
- 415
Unsupported content type; only application/JSON is supported
"""
- code = req.media["input"]
- args = req.media.get("args", ("",))
-
try:
- result = self.nsjail.python3(code, py_args=args)
+ result = self.nsjail.python3(
+ py_args=req.media["args"],
+ files=[EvalRequestFile.from_dict(file) for file in req.media.get("files", [])],
+ )
+ except FileParsingError as e:
+ raise falcon.HTTPBadRequest("Invalid file in request", str(e))
except Exception:
log.exception("An exception occurred while trying to process the request")
raise falcon.HTTPInternalServerError
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index 8cd32c7..723bd30 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -2,7 +2,6 @@ import logging
import re
import subprocess
import sys
-import textwrap
from tempfile import NamedTemporaryFile
from typing import Iterable
@@ -15,7 +14,7 @@ from snekbox.memfs import MemFS
__all__ = ("NsJail",)
from snekbox.process import EvalResult
-from snekbox.snekio import AttachmentError
+from snekbox.snekio import AttachmentError, EvalRequestFile
log = logging.getLogger(__name__)
@@ -139,11 +138,10 @@ class NsJail:
def python3(
self,
- code: str,
+ py_args: Iterable[str],
+ files: Iterable[EvalRequestFile] = (),
*,
nsjail_args: Iterable[str] = (),
- py_args: Iterable[str] = (),
- use_file: bool | None = None,
) -> EvalResult:
"""
Execute Python 3 code in an isolated environment and return the completed process.
@@ -195,30 +193,18 @@ class NsJail:
"--",
self.config.exec_bin.path,
*self.config.exec_bin.arg,
+ # Filter out empty strings (causes issues with python cli)
+ *(arg for arg in py_args if arg),
]
- # Filter out empty strings (causes issues with python cli)
- args.extend(s for s in py_args if s)
-
- c_arg = "c" in "".join(py_args)
-
- # Override for `timeit`
- if "timeit" in py_args:
- use_file = False
-
- match (use_file, c_arg):
- case (True, _) | (None, False):
- args.append("main.py")
- # Write the code to a file
- code_path = fs.home / "main.py"
- code_path.write_text(code)
- log.info(f"Created code file at [{code_path!r}].")
- case _:
- args.append(code)
+ # Write files if any
+ for file in files:
+ file.save_to(fs.home)
+ log.info(f"Created file at [{(fs.home / file.name)!r}].")
msg = "Executing code..."
if DEBUG:
- msg = f"{msg[:-3]}:\n{textwrap.indent(code, ' ')}\nWith the arguments {args}."
+ msg = f"{msg[:-3]}: With the arguments {args}."
log.info(msg)
try:
diff --git a/snekbox/snekio.py b/snekbox/snekio.py
index 3074041..26acd04 100644
--- a/snekbox/snekio.py
+++ b/snekbox/snekio.py
@@ -1,11 +1,12 @@
from __future__ import annotations
-import mimetypes
import zlib
from base64 import b64encode
from dataclasses import dataclass
from pathlib import Path
+RequestType = dict[str, str | bool | list[str | dict[str, str]]]
+
def sizeof_fmt(num: int, suffix: str = "B") -> str:
"""Return a human-readable file size."""
@@ -20,6 +21,40 @@ class AttachmentError(ValueError):
"""Raised when an attachment is invalid."""
+class FileParsingError(ValueError):
+ """Raised when a request file cannot be parsed."""
+
+
+@dataclass
+class EvalRequestFile:
+ """A file sent in an eval request."""
+
+ name: str
+ content: str
+
+ @classmethod
+ def from_dict(cls, data: dict[str, str]) -> EvalRequestFile:
+ """Convert a dict to a str attachment."""
+ name = data["name"]
+ path = Path(name)
+ parts = path.parts
+
+ if path.is_absolute() or set(parts[0]) & {"\\", "/"}:
+ raise FileParsingError(f"File path '{name}' must be relative")
+
+ if any(set(part) == {"."} for part in parts):
+ raise FileParsingError(f"File path '{name}' may not use traversal ('..')")
+
+ return cls(name, data.get("content", ""))
+
+ def save_to(self, directory: Path) -> None:
+ """Save the attachment to a path directory."""
+ file = Path(directory, self.name)
+ # Create directories if they don't exist
+ file.parent.mkdir(parents=True, exist_ok=True)
+ file.write_text(self.content)
+
+
@dataclass
class FileAttachment:
"""A file attachment."""
@@ -39,11 +74,6 @@ class FileAttachment:
return cls(file.name, file.read_bytes())
@property
- def mime(self) -> str:
- """MIME type of the attachment."""
- return mimetypes.guess_type(self.name)[0]
-
- @property
def size(self) -> int:
"""Size of the attachment."""
return len(self.content)
@@ -54,8 +84,6 @@ class FileAttachment:
content = b64encode(cmp).decode("ascii")
return {
"name": self.name,
- "mime": self.mime,
"size": self.size,
- "compression": "zlib",
"content": content,
}
diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py
index 976970e..caa848e 100644
--- a/tests/api/test_eval.py
+++ b/tests/api/test_eval.py
@@ -5,7 +5,7 @@ class TestEvalResource(SnekAPITestCase):
PATH = "/eval"
def test_post_valid_200(self):
- body = {"input": "foo"}
+ body = {"args": ["-c", "print('output')"]}
result = self.simulate_post(self.PATH, json=body)
self.assertEqual(result.status_code, 200)
@@ -20,26 +20,25 @@ class TestEvalResource(SnekAPITestCase):
expected = {
"title": "Request data failed validation",
- "description": "'input' is a required property",
+ "description": "'args' is a required property",
}
self.assertEqual(expected, result.json)
def test_post_invalid_data_400(self):
- bodies = ({"input": 400}, {"input": "", "args": [400]})
-
- for body in bodies:
+ bodies = ({"args": 400}, {"args": [], "files": [215]})
+ expects = ["400 is not of type 'array'", "215 is not of type 'object'"]
+ for body, expected in zip(bodies, expects):
with self.subTest():
result = self.simulate_post(self.PATH, json=body)
self.assertEqual(result.status_code, 400)
- expected = {
+ expected_json = {
"title": "Request data failed validation",
- "description": "400 is not of type 'string'",
+ "description": expected,
}
-
- self.assertEqual(expected, result.json)
+ self.assertEqual(expected_json, result.json)
def test_post_invalid_content_type_415(self):
body = "{'input': 'foo'}"
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 7c5db2b..eba5e60 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -7,7 +7,7 @@ from tests.gunicorn_utils import run_gunicorn
def run_code_in_snekbox(code: str) -> tuple[str, int]:
- body = {"input": code}
+ body = {"args": ["-c", code]}
json_data = json.dumps(body).encode("utf-8")
req = urllib.request.Request("http://localhost:8060/eval")
diff --git a/tests/test_main.py b/tests/test_main.py
index 1e6cbc5..24c067c 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -12,10 +12,10 @@ import snekbox.__main__ as snekbox_main
class ArgParseTests(unittest.TestCase):
def test_parse_args(self):
subtests = (
- (["", "code"], Namespace(code="code", nsjail_args=[], py_args=[])),
+ (["", "code"], Namespace(code="code", nsjail_args=[], py_args=["-c"])),
(
["", "code", "--time_limit", "0"],
- Namespace(code="code", nsjail_args=["--time_limit", "0"], py_args=[]),
+ Namespace(code="code", nsjail_args=["--time_limit", "0"], py_args=["-c"]),
),
(
["", "code", "---", "-m", "timeit"],
@@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase):
@patch("sys.argv", ["", "import sys; sys.exit(22)"])
def test_main_exits_with_returncode(self):
- """Should exit with the subprocess's returncode if it's non-zero."""
+ """Should exit with the subprocess returncode if it's non-zero."""
with self.assertRaises(SystemExit) as cm:
snekbox_main.main()
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index cea96bd..324b88a 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -10,6 +10,7 @@ from pathlib import Path
from textwrap import dedent
from snekbox.nsjail import NsJail
+from snekbox.snekio import EvalRequestFile
class NsJailTests(unittest.TestCase):
@@ -20,17 +21,26 @@ class NsJailTests(unittest.TestCase):
self.logger = logging.getLogger("snekbox.nsjail")
self.logger.setLevel(logging.WARNING)
+ def eval_code(self, code: str):
+ return self.nsjail.python3(["-c", code])
+
+ def eval_file(self, code: str, name: str = "test.py"):
+ file = EvalRequestFile(name, code)
+ return self.nsjail.python3([name], [file])
+
def test_print_returns_0(self):
- result = self.nsjail.python3("print('test')")
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.stdout, "test\n")
- self.assertEqual(result.stderr, None)
+ for fn in (self.eval_code, self.eval_file):
+ with self.subTest(fn.__name__):
+ result = fn("print('test')")
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "test\n")
+ self.assertEqual(result.stderr, None)
def test_timeout_returns_137(self):
code = "while True: pass"
with self.assertLogs(self.logger) as log:
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 137)
self.assertEqual(result.stdout, "")
@@ -41,7 +51,7 @@ class NsJailTests(unittest.TestCase):
# Add a kilobyte just to be safe.
code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}"
- result = self.nsjail.python3(code, py_args=("-c",))
+ result = self.eval_file(code)
self.assertEqual(result.stdout, "")
self.assertEqual(result.returncode, 137)
self.assertEqual(result.stderr, None)
@@ -64,7 +74,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
# Also expect n-1 processes to be opened
@@ -96,7 +106,7 @@ class NsJailTests(unittest.TestCase):
"""
)
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
exit_codes = result.stdout.strip().split()
self.assertIn("-9", exit_codes)
@@ -112,7 +122,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Read-only file system", result.stdout)
self.assertEqual(result.stderr, None)
@@ -127,7 +137,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 0)
self.assertEqual(result.stdout, "hello\n")
self.assertEqual(result.stderr, None)
@@ -145,7 +155,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No space left on device", result.stdout)
self.assertEqual(result.stderr, None)
@@ -165,7 +175,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 0)
self.assertEqual(
result.stdout,
@@ -184,7 +194,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
self.assertEqual(result.stderr, None)
@@ -197,7 +207,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 139)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
@@ -205,19 +215,19 @@ class NsJailTests(unittest.TestCase):
def test_null_byte_value_error(self):
# This error does not occur without -c, where it
# would be a normal SyntaxError.
- result = self.nsjail.python3("\0", py_args=("-c",))
+ result = self.nsjail.python3(["-c", "\0"])
self.assertEqual(result.returncode, None)
self.assertEqual(result.stdout, "ValueError: embedded null byte")
self.assertEqual(result.stderr, None)
def test_print_bad_unicode_encode_error(self):
- result = self.nsjail.python3("print(chr(56550))")
+ result = self.eval_file("print(chr(56550))")
self.assertEqual(result.returncode, 1)
self.assertIn("UnicodeEncodeError", result.stdout)
self.assertEqual(result.stderr, None)
def test_unicode_env_erase_escape_fails(self):
- result = self.nsjail.python3(
+ result = self.eval_file(
dedent(
"""
import os
@@ -271,7 +281,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No such file or directory", result.stdout)
self.assertEqual(result.stderr, None)
@@ -287,13 +297,13 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Function not implemented", result.stdout)
self.assertEqual(result.stderr, None)
def test_numpy_import(self):
- result = self.nsjail.python3("import numpy")
+ result = self.eval_file("import numpy")
self.assertEqual(result.returncode, 0)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
@@ -308,7 +318,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertLess(
result.stdout.find(stdout_msg),
result.stdout.find(stderr_msg),
@@ -319,7 +329,7 @@ class NsJailTests(unittest.TestCase):
def test_stdout_flood_results_in_graceful_sigterm(self):
code = "while True: print('abcdefghij')"
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 143)
def test_large_output_is_truncated(self):
@@ -337,17 +347,17 @@ class NsJailTests(unittest.TestCase):
def test_nsjail_args(self):
args = ["foo", "bar"]
- result = self.nsjail.python3("", nsjail_args=args)
+ result = self.nsjail.python3((), nsjail_args=args)
end = result.args.index("--")
self.assertEqual(result.args[end - len(args) : end], args)
def test_py_args(self):
args = ["-m", "timeit"]
- result = self.nsjail.python3("", py_args=args)
+ result = self.nsjail.python3(args)
self.assertEqual(result.returncode, 0)
- self.assertEqual(result.args[-3:-1], args)
+ self.assertEqual(result.args[-2:], args)
class NsJailArgsTests(unittest.TestCase):