aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/__init__.py4
-rw-r--r--tests/api/test_eval.py106
-rw-r--r--tests/test_filesystem.py143
-rw-r--r--tests/test_integration.py81
-rw-r--r--tests/test_main.py2
-rw-r--r--tests/test_memfs.py65
-rw-r--r--tests/test_nsjail.py170
-rw-r--r--tests/test_snekio.py58
8 files changed, 578 insertions, 51 deletions
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
index 0e6e422..5f20faf 100644
--- a/tests/api/__init__.py
+++ b/tests/api/__init__.py
@@ -1,10 +1,10 @@
import logging
-from subprocess import CompletedProcess
from unittest import mock
from falcon import testing
from snekbox.api import SnekAPI
+from snekbox.process import EvalResult
class SnekAPITestCase(testing.TestCase):
@@ -13,7 +13,7 @@ class SnekAPITestCase(testing.TestCase):
self.patcher = mock.patch("snekbox.api.snekapi.NsJail", autospec=True)
self.mock_nsjail = self.patcher.start()
- self.mock_nsjail.return_value.python3.return_value = CompletedProcess(
+ self.mock_nsjail.return_value.python3.return_value = EvalResult(
args=[], returncode=0, stdout="output", stderr="error"
)
self.addCleanup(self.patcher.stop)
diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py
index 976970e..37f90e7 100644
--- a/tests/api/test_eval.py
+++ b/tests/api/test_eval.py
@@ -5,12 +5,19 @@ class TestEvalResource(SnekAPITestCase):
PATH = "/eval"
def test_post_valid_200(self):
- body = {"input": "foo"}
- result = self.simulate_post(self.PATH, json=body)
-
- self.assertEqual(result.status_code, 200)
- self.assertEqual("output", result.json["stdout"])
- self.assertEqual(0, result.json["returncode"])
+ cases = [
+ {"args": ["-c", "print('output')"]},
+ {"input": "print('hello')"},
+ {"input": "print('hello')", "args": ["-c"]},
+ {"input": "print('hello')", "args": [""]},
+ {"input": "pass", "args": ["-m", "timeit"]},
+ ]
+ for body in cases:
+ with self.subTest():
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 200)
+ self.assertEqual("output", result.json["stdout"])
+ self.assertEqual(0, result.json["returncode"])
def test_post_invalid_schema_400(self):
body = {"stuff": "foo"}
@@ -20,27 +27,100 @@ class TestEvalResource(SnekAPITestCase):
expected = {
"title": "Request data failed validation",
- "description": "'input' is a required property",
+ "description": "{'stuff': 'foo'} is not valid under any of the given schemas",
}
self.assertEqual(expected, result.json)
def test_post_invalid_data_400(self):
- bodies = ({"input": 400}, {"input": "", "args": [400]})
-
- for body in bodies:
+ bodies = ({"args": 400}, {"args": [], "files": [215]})
+ expects = ["400 is not of type 'array'", "215 is not of type 'object'"]
+ for body, expected in zip(bodies, expects):
with self.subTest():
result = self.simulate_post(self.PATH, json=body)
self.assertEqual(result.status_code, 400)
- expected = {
+ expected_json = {
"title": "Request data failed validation",
- "description": "400 is not of type 'string'",
+ "description": expected,
+ }
+ self.assertEqual(expected_json, result.json)
+
+ def test_files_path(self):
+ """Normal paths should work with 200."""
+ test_paths = [
+ "file.txt",
+ "./0.jpg",
+ "path/to/file",
+ "folder/../hm",
+ "folder/./to/./somewhere",
+ "traversal/but/../not/beyond/../root",
+ r"backslash\\okay",
+ r"backslash\okay",
+ "numbers/0123456789",
+ ]
+ for path in test_paths:
+ with self.subTest(path=path):
+ body = {"args": ["test.py"], "files": [{"path": path}]}
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 200)
+ self.assertEqual("output", result.json["stdout"])
+ self.assertEqual(0, result.json["returncode"])
+
+ def test_files_illegal_path_traversal(self):
+ """Traversal beyond root should be denied with 400 error."""
+ test_paths = [
+ "../secrets",
+ "../../dir",
+ "dir/../../secrets",
+ "dir/var/../../../file",
+ ]
+ for path in test_paths:
+ with self.subTest(path=path):
+ body = {"args": ["test.py"], "files": [{"path": path}]}
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 400)
+ expected = {
+ "title": "Request file is invalid",
+ "description": f"File path '{path}' may not traverse beyond root",
}
-
self.assertEqual(expected, result.json)
+ def test_files_illegal_path_absolute(self):
+ """Absolute file paths should 400-error at json schema validation stage."""
+ test_paths = [
+ "/",
+ "/etc",
+ "/etc/vars/secrets",
+ "/absolute",
+ "/file.bin",
+ ]
+ for path in test_paths:
+ with self.subTest(path=path):
+ body = {"args": ["test.py"], "files": [{"path": path}]}
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 400)
+ self.assertEqual("Request data failed validation", result.json["title"])
+ self.assertIn("does not match", result.json["description"])
+
+ def test_files_illegal_path_null_byte(self):
+ """Paths containing \0 should 400-error at json schema validation stage."""
+ test_paths = [
+ r"etc/passwd\0",
+ r"a\0b",
+ r"\0",
+ r"\\0",
+ r"var/\0/path",
+ ]
+ for path in test_paths:
+ with self.subTest(path=path):
+ body = {"args": ["test.py"], "files": [{"path": path}]}
+ result = self.simulate_post(self.PATH, json=body)
+ self.assertEqual(result.status_code, 400)
+ self.assertEqual("Request data failed validation", result.json["title"])
+ self.assertIn("does not match", result.json["description"])
+
def test_post_invalid_content_type_415(self):
body = "{'input': 'foo'}"
headers = {"Content-Type": "application/xml"}
diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py
new file mode 100644
index 0000000..e4d081f
--- /dev/null
+++ b/tests/test_filesystem.py
@@ -0,0 +1,143 @@
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager, suppress
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+from uuid import uuid4
+
+from snekbox.filesystem import UnmountFlags, mount, unmount
+
+
+class LibMountTests(TestCase):
+ temp_dir: TemporaryDirectory
+
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.temp_dir.cleanup()
+
+ @contextmanager
+ def get_mount(self):
+ """Yield a valid mount point and unmount after context."""
+ path = Path(self.temp_dir.name, str(uuid4()))
+ path.mkdir()
+ try:
+ mount(source="", target=path, fs="tmpfs")
+ yield path
+ finally:
+ with suppress(OSError):
+ unmount(path)
+
+ def test_mount(self):
+ """Test normal mounting."""
+ with self.get_mount() as path:
+ self.assertTrue(path.is_mount())
+ self.assertTrue(path.exists())
+ self.assertFalse(path.is_mount())
+ # Unmounting should not remove the original folder
+ self.assertTrue(path.exists())
+
+ def test_mount_errors(self):
+ """Test invalid mount errors."""
+ cases = [
+ (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"),
+ (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"),
+ (
+ dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"),
+ OSError,
+ "Invalid argument",
+ ),
+ ]
+ for case, err, msg in cases:
+ with self.subTest(case=case):
+ with self.assertRaises(err) as cm:
+ mount(**case)
+ self.assertIn(msg, str(cm.exception))
+
+ def test_mount_duplicate(self):
+ """Test attempted mount after mounted."""
+ path = Path(self.temp_dir.name, str(uuid4()))
+ path.mkdir()
+ try:
+ mount(source="", target=path, fs="tmpfs")
+ with self.assertRaises(OSError) as cm:
+ mount(source="", target=path, fs="tmpfs")
+ self.assertIn("already a mount point", str(cm.exception))
+ finally:
+ unmount(target=path)
+
+ def test_unmount_flags(self):
+ """Test unmount flags."""
+ flags = [
+ UnmountFlags.MNT_FORCE,
+ UnmountFlags.MNT_DETACH,
+ UnmountFlags.UMOUNT_NOFOLLOW,
+ ]
+ for flag in flags:
+ with self.subTest(flag=flag), self.get_mount() as path:
+ self.assertTrue(path.is_mount())
+ unmount(path, flag)
+ self.assertFalse(path.is_mount())
+
+ def test_unmount_flags_expire(self):
+ """Test unmount MNT_EXPIRE behavior."""
+ with self.get_mount() as path:
+ with self.assertRaises(BlockingIOError):
+ unmount(path, UnmountFlags.MNT_EXPIRE)
+
+ def test_unmount_errors(self):
+ """Test invalid unmount errors."""
+ cases = [
+ (dict(target="not/exist"), OSError, "is not a mount point"),
+ (dict(target=Path("not/exist")), OSError, "is not a mount point"),
+ ]
+ for case, err, msg in cases:
+ with self.subTest(case=case):
+ with self.assertRaises(err) as cm:
+ unmount(**case)
+ self.assertIn(msg, str(cm.exception))
+
+ def test_unmount_invalid_args(self):
+ """Test invalid unmount invalid flag."""
+ with self.get_mount() as path:
+ with self.assertRaises(OSError) as cm:
+ unmount(path, 251)
+ self.assertIn("Invalid argument", str(cm.exception))
+
+ def test_threading(self):
+ """Test concurrent mounting works in multi-thread environments."""
+ paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)]
+
+ for path in paths:
+ path.mkdir()
+ self.assertFalse(path.is_mount())
+
+ try:
+ with ThreadPoolExecutor() as pool:
+ res = list(
+ pool.map(
+ mount,
+ [""] * len(paths),
+ paths,
+ ["tmpfs"] * len(paths),
+ )
+ )
+ self.assertEqual(len(res), len(paths))
+
+ for path in paths:
+ with self.subTest(path=path):
+ self.assertTrue(path.is_mount())
+
+ unmounts = list(pool.map(unmount, paths))
+ self.assertEqual(len(unmounts), len(paths))
+
+ for path in paths:
+ with self.subTest(path=path):
+ self.assertFalse(path.is_mount())
+ finally:
+ with suppress(OSError):
+ for path in paths:
+ unmount(path)
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 7c5db2b..91b01e6 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,14 +1,25 @@
import json
import unittest
import urllib.request
+from base64 import b64encode
from multiprocessing.dummy import Pool
+from textwrap import dedent
from tests.gunicorn_utils import run_gunicorn
-def run_code_in_snekbox(code: str) -> tuple[str, int]:
- body = {"input": code}
- json_data = json.dumps(body).encode("utf-8")
+def b64encode_code(data: str):
+ data = dedent(data).strip()
+ return b64encode(data.encode()).decode("ascii")
+
+
+def snekbox_run_code(code: str) -> tuple[str, int]:
+ body = {"args": ["-c", code]}
+ return snekbox_request(body)
+
+
+def snekbox_request(content: dict) -> tuple[str, int]:
+ json_data = json.dumps(content).encode("utf-8")
req = urllib.request.Request("http://localhost:8060/eval")
req.add_header("Content-Type", "application/json; charset=utf-8")
@@ -34,9 +45,71 @@ class IntegrationTests(unittest.TestCase):
args = [code] * processes
with Pool(processes) as p:
- results = p.map(run_code_in_snekbox, args)
+ results = p.map(snekbox_run_code, args)
responses, statuses = zip(*results)
self.assertTrue(all(status == 200 for status in statuses))
self.assertTrue(all(json.loads(response)["returncode"] == 0 for response in responses))
+
+ def test_eval(self):
+ """Test normal eval requests without files."""
+ with run_gunicorn():
+ cases = [
+ ({"input": "print('Hello')"}, "Hello\n"),
+ ({"args": ["-c", "print('abc12')"]}, "abc12\n"),
+ ]
+ for body, expected in cases:
+ with self.subTest(body=body):
+ response, status = snekbox_request(body)
+ self.assertEqual(status, 200)
+ self.assertEqual(json.loads(response)["stdout"], expected)
+
+ def test_files_send_receive(self):
+ """Test sending and receiving files to snekbox."""
+ with run_gunicorn():
+ request = {
+ "args": ["main.py"],
+ "files": [
+ {
+ "path": "main.py",
+ "content": b64encode_code(
+ """
+ from pathlib import Path
+ from mod import lib
+ print(lib.var)
+
+ with open('test.txt', 'w') as f:
+ f.write('test 1')
+
+ Path('dir').mkdir()
+ Path('dir/test2.txt').write_text('test 2')
+ """
+ ),
+ },
+ {"path": "mod/__init__.py"},
+ {"path": "mod/lib.py", "content": b64encode_code("var = 'hello'")},
+ ],
+ }
+
+ expected = {
+ "stdout": "hello\n",
+ "returncode": 0,
+ "files": [
+ {
+ "path": "dir/test2.txt",
+ "size": len("test 2"),
+ "content": b64encode_code("test 2"),
+ },
+ {
+ "path": "test.txt",
+ "size": len("test 1"),
+ "content": b64encode_code("test 1"),
+ },
+ ],
+ }
+
+ response, status = snekbox_request(request)
+
+ self.assertEqual(200, status)
+ self.assertEqual(expected, json.loads(response))
diff --git a/tests/test_main.py b/tests/test_main.py
index 77b3130..24c067c 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase):
@patch("sys.argv", ["", "import sys; sys.exit(22)"])
def test_main_exits_with_returncode(self):
- """Should exit with the subprocess's returncode if it's non-zero."""
+ """Should exit with the subprocess returncode if it's non-zero."""
with self.assertRaises(SystemExit) as cm:
snekbox_main.main()
diff --git a/tests/test_memfs.py b/tests/test_memfs.py
new file mode 100644
index 0000000..0555726
--- /dev/null
+++ b/tests/test_memfs.py
@@ -0,0 +1,65 @@
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import ExitStack
+from unittest import TestCase, mock
+from uuid import uuid4
+
+from snekbox.memfs import MemFS
+
+UUID_TEST = uuid4()
+
+
+class MemFSTests(TestCase):
+ def setUp(self):
+ super().setUp()
+ self.logger = logging.getLogger("snekbox.memfs")
+ self.logger.setLevel(logging.WARNING)
+
+ @mock.patch("snekbox.memfs.uuid4", lambda: UUID_TEST)
+ def test_assignment_thread_safe(self):
+ """Test concurrent mounting works in multi-thread environments."""
+ # Concurrently create MemFS in threads, check only 1 can be created
+ # Others should result in RuntimeError
+ with ExitStack() as stack:
+ with ThreadPoolExecutor() as executor:
+ memfs: MemFS | None = None
+ # Each future uses enter_context to ensure __exit__ on test exception
+ futures = [
+ executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8)
+ ]
+ for future in futures:
+ # We should have exactly one result and all others RuntimeErrors
+ if err := future.exception():
+ self.assertIsInstance(err, RuntimeError)
+ else:
+ self.assertIsNone(memfs)
+ memfs = future.result()
+
+ # Original memfs should still exist afterwards
+ self.assertIsInstance(memfs, MemFS)
+ self.assertTrue(memfs.path.is_mount())
+
+ def test_cleanup(self):
+ """Test explicit cleanup."""
+ memfs = MemFS(10)
+ path = memfs.path
+ self.assertTrue(path.is_mount())
+ memfs.cleanup()
+ self.assertFalse(path.exists())
+
+ def test_context_cleanup(self):
+ """Context __exit__ should trigger cleanup."""
+ with MemFS(10) as memfs:
+ path = memfs.path
+ self.assertTrue(path.is_mount())
+ self.assertFalse(path.exists())
+
+ def test_implicit_cleanup(self):
+ """Test implicit _cleanup triggered by GC."""
+ memfs = MemFS(10)
+ path = memfs.path
+ self.assertTrue(path.is_mount())
+ # Catch the warning about implicit cleanup
+ with self.assertWarns(ResourceWarning):
+ del memfs
+ self.assertFalse(path.exists())
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index 6f6e2a7..cad79f3 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -9,28 +9,40 @@ from itertools import product
from pathlib import Path
from textwrap import dedent
+from snekbox.filesystem import Size
from snekbox.nsjail import NsJail
+from snekbox.snekio import FileAttachment
class NsJailTests(unittest.TestCase):
def setUp(self):
super().setUp()
- self.nsjail = NsJail()
+ # Specify lower limits for unit tests to complete within time limits
+ self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB)
self.logger = logging.getLogger("snekbox.nsjail")
self.logger.setLevel(logging.WARNING)
+ def eval_code(self, code: str):
+ return self.nsjail.python3(["-c", code])
+
+ def eval_file(self, code: str, name: str = "test.py", **kwargs):
+ file = FileAttachment(name, code.encode())
+ return self.nsjail.python3([name], [file], **kwargs)
+
def test_print_returns_0(self):
- result = self.nsjail.python3("print('test')")
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.stdout, "test\n")
- self.assertEqual(result.stderr, None)
+ for fn in (self.eval_code, self.eval_file):
+ with self.subTest(fn.__name__):
+ result = fn("print('test')")
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "test\n")
+ self.assertEqual(result.stderr, None)
def test_timeout_returns_137(self):
code = "while True: pass"
with self.assertLogs(self.logger) as log:
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 137)
self.assertEqual(result.stdout, "")
@@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase):
# Add a kilobyte just to be safe.
code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}"
- result = self.nsjail.python3(code)
- self.assertEqual(result.returncode, 137)
+ result = self.eval_file(code)
self.assertEqual(result.stdout, "")
+ self.assertEqual(result.returncode, 137)
+ self.assertEqual(result.stderr, None)
+
+ def test_multi_files(self):
+ files = [
+ FileAttachment("main.py", "import lib; print(lib.x)".encode()),
+ FileAttachment("lib.py", "x = 'hello'".encode()),
+ ]
+
+ result = self.nsjail.python3(["main.py"], files)
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "hello\n")
self.assertEqual(result.stderr, None)
def test_subprocess_resource_unavailable(self):
+ max_pids = self.nsjail.config.cgroup_pids_max
code = dedent(
- """
+ f"""
import subprocess
- # Max PIDs is 5.
- for _ in range(6):
+ # Should fail at n (max PIDs) since the caller python process counts as well
+ for _ in range({max_pids}):
print(subprocess.Popen(
[
'/usr/local/bin/python3',
@@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
+ # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n"
+ expected = "\n".join(map(str, range(2, max_pids + 1)))
+ self.assertIn(expected, result.stdout)
self.assertEqual(result.stderr, None)
def test_multiprocess_resource_limits(self):
@@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase):
"""
)
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
exit_codes = result.stdout.strip().split()
self.assertIn("-9", exit_codes)
@@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Read-only file system", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_write(self):
+ code = dedent(
+ """
+ from pathlib import Path
+ with open('test.txt', 'w') as f:
+ f.write('hello')
+ print(Path('test.txt').read_text())
+ """
+ ).strip()
+
+ result = self.eval_file(code)
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "hello\n")
+ self.assertEqual(result.stderr, None)
+
+ def test_write_exceed_space(self):
+ code = dedent(
+ f"""
+ size = {self.nsjail.memfs_instance_size} // 2048
+ with open('f.bin', 'wb') as f:
+ for i in range(size):
+ f.write(b'1' * 2048)
+ """
+ ).strip()
+
+ result = self.eval_file(code)
+ self.assertEqual(result.returncode, 1)
+ self.assertIn("No space left on device", result.stdout)
+ self.assertEqual(result.stderr, None)
+
def test_forkbomb_resource_unavailable(self):
code = dedent(
"""
@@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_file_parsing_timeout(self):
+ code = dedent(
+ """
+ import os
+ data = "a" * 1024
+ size = 32 * 1024 * 1024
+
+ with open("file", "w") as f:
+ for _ in range((size // 1024) - 5):
+ f.write(data)
+
+ for i in range(100):
+ os.symlink("file", f"file{i}")
+ """
+ ).strip()
+
+ nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, None)
+ self.assertEqual(
+ result.stdout, "TimeoutError: Exceeded time limit while parsing attachments"
+ )
+ self.assertEqual(result.stderr, None)
+
+ def test_file_write_error(self):
+ """Test errors during file write."""
+ result = self.nsjail.python3(
+ [""],
+ [
+ FileAttachment("dir/test.txt", b"abc"),
+ FileAttachment("dir", b"xyz"),
+ ],
+ )
+
+ self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.")
+ self.assertEqual(result.stderr, None)
+ self.assertEqual(result.returncode, None)
+
def test_sigsegv_returns_139(self): # In honour of Juan.
code = dedent(
"""
@@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 139)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
def test_null_byte_value_error(self):
- result = self.nsjail.python3("\0")
+ # This error only occurs with `-c` mode
+ result = self.eval_code("\0")
self.assertEqual(result.returncode, None)
self.assertEqual(result.stdout, "ValueError: embedded null byte")
self.assertEqual(result.stderr, None)
def test_print_bad_unicode_encode_error(self):
- result = self.nsjail.python3("print(chr(56550))")
+ result = self.eval_file("print(chr(56550))")
self.assertEqual(result.returncode, 1)
self.assertIn("UnicodeEncodeError", result.stdout)
self.assertEqual(result.stderr, None)
def test_unicode_env_erase_escape_fails(self):
- result = self.nsjail.python3(
+ result = self.eval_file(
dedent(
"""
import os
@@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No such file or directory", result.stdout)
self.assertEqual(result.stderr, None)
@@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Function not implemented", result.stdout)
self.assertEqual(result.stderr, None)
def test_numpy_import(self):
- result = self.nsjail.python3("import numpy")
+ result = self.eval_file("import numpy")
self.assertEqual(result.returncode, 0)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
@@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertLess(
result.stdout.find(stdout_msg),
result.stdout.find(stderr_msg),
@@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase):
def test_stdout_flood_results_in_graceful_sigterm(self):
code = "while True: print('abcdefghij')"
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 143)
def test_large_output_is_truncated(self):
@@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase):
self.assertEqual(output, chunk * expected_chunks)
def test_nsjail_args(self):
- args = ("foo", "bar")
- result = self.nsjail.python3("", nsjail_args=args)
+ args = ["foo", "bar"]
+ result = self.nsjail.python3((), nsjail_args=args)
end = result.args.index("--")
self.assertEqual(result.args[end - len(args) : end], args)
def test_py_args(self):
- args = ("-m", "timeit")
- result = self.nsjail.python3("", py_args=args)
-
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.args[-3:-1], args)
+ cases = [
+ # Normal args
+ (["-c", "print('hello')"], ["-c", "print('hello')"]),
+ # Leading empty strings should be removed
+ (["", "-m", "timeit"], ["-m", "timeit"]),
+ (["", "", "-m", "timeit"], ["-m", "timeit"]),
+ (["", "", "", "-m", "timeit"], ["-m", "timeit"]),
+ # Non-leading empty strings should be preserved
+ (["-m", "timeit", ""], ["-m", "timeit", ""]),
+ ]
+
+ for args, expected in cases:
+ with self.subTest(args=args):
+ result = self.nsjail.python3(py_args=args)
+ idx = result.args.index("-BSqu")
+ self.assertEqual(result.args[idx + 1 :], expected)
+ self.assertEqual(result.returncode, 0)
class NsJailArgsTests(unittest.TestCase):
diff --git a/tests/test_snekio.py b/tests/test_snekio.py
new file mode 100644
index 0000000..8f04429
--- /dev/null
+++ b/tests/test_snekio.py
@@ -0,0 +1,58 @@
+from unittest import TestCase
+
+from snekbox import snekio
+from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError
+
+
+class SnekIOTests(TestCase):
+ def test_safe_path(self) -> None:
+ cases = [
+ ("", ""),
+ ("foo", "foo"),
+ ("foo/bar", "foo/bar"),
+ ("foo/bar.ext", "foo/bar.ext"),
+ ]
+
+ for path, expected in cases:
+ self.assertEqual(snekio.safe_path(path), expected)
+
+ def test_safe_path_raise(self):
+ cases = [
+ ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"),
+ ("/foo", IllegalPathError, "File path '/foo' must be relative"),
+ ]
+
+ for path, error, msg in cases:
+ with self.assertRaises(error) as cm:
+ snekio.safe_path(path)
+ self.assertEqual(str(cm.exception), msg)
+
+ def test_file_from_dict(self):
+ cases = [
+ ({"path": "foo", "content": ""}, FileAttachment("foo", b"")),
+ ({"path": "foo"}, FileAttachment("foo", b"")),
+ ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")),
+ ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")),
+ ]
+
+ for data, expected in cases:
+ self.assertEqual(FileAttachment.from_dict(data), expected)
+
+ def test_file_from_dict_error(self):
+ cases = [
+ (
+ {"path": "foo", "content": "9"},
+ ParsingError,
+ "Invalid base64 encoding for file 'foo'",
+ ),
+ (
+ {"path": "var/a.txt", "content": "1="},
+ ParsingError,
+ "Invalid base64 encoding for file 'var/a.txt'",
+ ),
+ ]
+
+ for data, error, msg in cases:
+ with self.assertRaises(error) as cm:
+ FileAttachment.from_dict(data)
+ self.assertEqual(str(cm.exception), msg)