diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/__init__.py | 4 | ||||
-rw-r--r-- | tests/api/test_eval.py | 106 | ||||
-rw-r--r-- | tests/test_filesystem.py | 143 | ||||
-rw-r--r-- | tests/test_integration.py | 81 | ||||
-rw-r--r-- | tests/test_main.py | 2 | ||||
-rw-r--r-- | tests/test_memfs.py | 65 | ||||
-rw-r--r-- | tests/test_nsjail.py | 170 | ||||
-rw-r--r-- | tests/test_snekio.py | 58 |
8 files changed, 578 insertions, 51 deletions
diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 0e6e422..5f20faf 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -1,10 +1,10 @@ import logging -from subprocess import CompletedProcess from unittest import mock from falcon import testing from snekbox.api import SnekAPI +from snekbox.process import EvalResult class SnekAPITestCase(testing.TestCase): @@ -13,7 +13,7 @@ class SnekAPITestCase(testing.TestCase): self.patcher = mock.patch("snekbox.api.snekapi.NsJail", autospec=True) self.mock_nsjail = self.patcher.start() - self.mock_nsjail.return_value.python3.return_value = CompletedProcess( + self.mock_nsjail.return_value.python3.return_value = EvalResult( args=[], returncode=0, stdout="output", stderr="error" ) self.addCleanup(self.patcher.stop) diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py index 976970e..37f90e7 100644 --- a/tests/api/test_eval.py +++ b/tests/api/test_eval.py @@ -5,12 +5,19 @@ class TestEvalResource(SnekAPITestCase): PATH = "/eval" def test_post_valid_200(self): - body = {"input": "foo"} - result = self.simulate_post(self.PATH, json=body) - - self.assertEqual(result.status_code, 200) - self.assertEqual("output", result.json["stdout"]) - self.assertEqual(0, result.json["returncode"]) + cases = [ + {"args": ["-c", "print('output')"]}, + {"input": "print('hello')"}, + {"input": "print('hello')", "args": ["-c"]}, + {"input": "print('hello')", "args": [""]}, + {"input": "pass", "args": ["-m", "timeit"]}, + ] + for body in cases: + with self.subTest(): + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 200) + self.assertEqual("output", result.json["stdout"]) + self.assertEqual(0, result.json["returncode"]) def test_post_invalid_schema_400(self): body = {"stuff": "foo"} @@ -20,27 +27,100 @@ class TestEvalResource(SnekAPITestCase): expected = { "title": "Request data failed validation", - "description": "'input' is a required property", + "description": "{'stuff': 'foo'} is not valid under any of the given schemas", } self.assertEqual(expected, result.json) def test_post_invalid_data_400(self): - bodies = ({"input": 400}, {"input": "", "args": [400]}) - - for body in bodies: + bodies = ({"args": 400}, {"args": [], "files": [215]}) + expects = ["400 is not of type 'array'", "215 is not of type 'object'"] + for body, expected in zip(bodies, expects): with self.subTest(): result = self.simulate_post(self.PATH, json=body) self.assertEqual(result.status_code, 400) - expected = { + expected_json = { "title": "Request data failed validation", - "description": "400 is not of type 'string'", + "description": expected, + } + self.assertEqual(expected_json, result.json) + + def test_files_path(self): + """Normal paths should work with 200.""" + test_paths = [ + "file.txt", + "./0.jpg", + "path/to/file", + "folder/../hm", + "folder/./to/./somewhere", + "traversal/but/../not/beyond/../root", + r"backslash\\okay", + r"backslash\okay", + "numbers/0123456789", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 200) + self.assertEqual("output", result.json["stdout"]) + self.assertEqual(0, result.json["returncode"]) + + def test_files_illegal_path_traversal(self): + """Traversal beyond root should be denied with 400 error.""" + test_paths = [ + "../secrets", + "../../dir", + "dir/../../secrets", + "dir/var/../../../file", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 400) + expected = { + "title": "Request file is invalid", + "description": f"File path '{path}' may not traverse beyond root", } - self.assertEqual(expected, result.json) + def test_files_illegal_path_absolute(self): + """Absolute file paths should 400-error at json schema validation stage.""" + test_paths = [ + "/", + "/etc", + "/etc/vars/secrets", + "/absolute", + "/file.bin", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 400) + self.assertEqual("Request data failed validation", result.json["title"]) + self.assertIn("does not match", result.json["description"]) + + def test_files_illegal_path_null_byte(self): + """Paths containing \0 should 400-error at json schema validation stage.""" + test_paths = [ + r"etc/passwd\0", + r"a\0b", + r"\0", + r"\\0", + r"var/\0/path", + ] + for path in test_paths: + with self.subTest(path=path): + body = {"args": ["test.py"], "files": [{"path": path}]} + result = self.simulate_post(self.PATH, json=body) + self.assertEqual(result.status_code, 400) + self.assertEqual("Request data failed validation", result.json["title"]) + self.assertIn("does not match", result.json["description"]) + def test_post_invalid_content_type_415(self): body = "{'input': 'foo'}" headers = {"Content-Type": "application/xml"} diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 0000000..e4d081f --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,143 @@ +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager, suppress +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import TestCase +from uuid import uuid4 + +from snekbox.filesystem import UnmountFlags, mount, unmount + + +class LibMountTests(TestCase): + temp_dir: TemporaryDirectory + + @classmethod + def setUpClass(cls): + cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests") + + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + + @contextmanager + def get_mount(self): + """Yield a valid mount point and unmount after context.""" + path = Path(self.temp_dir.name, str(uuid4())) + path.mkdir() + try: + mount(source="", target=path, fs="tmpfs") + yield path + finally: + with suppress(OSError): + unmount(path) + + def test_mount(self): + """Test normal mounting.""" + with self.get_mount() as path: + self.assertTrue(path.is_mount()) + self.assertTrue(path.exists()) + self.assertFalse(path.is_mount()) + # Unmounting should not remove the original folder + self.assertTrue(path.exists()) + + def test_mount_errors(self): + """Test invalid mount errors.""" + cases = [ + (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"), + (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"), + ( + dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"), + OSError, + "Invalid argument", + ), + ] + for case, err, msg in cases: + with self.subTest(case=case): + with self.assertRaises(err) as cm: + mount(**case) + self.assertIn(msg, str(cm.exception)) + + def test_mount_duplicate(self): + """Test attempted mount after mounted.""" + path = Path(self.temp_dir.name, str(uuid4())) + path.mkdir() + try: + mount(source="", target=path, fs="tmpfs") + with self.assertRaises(OSError) as cm: + mount(source="", target=path, fs="tmpfs") + self.assertIn("already a mount point", str(cm.exception)) + finally: + unmount(target=path) + + def test_unmount_flags(self): + """Test unmount flags.""" + flags = [ + UnmountFlags.MNT_FORCE, + UnmountFlags.MNT_DETACH, + UnmountFlags.UMOUNT_NOFOLLOW, + ] + for flag in flags: + with self.subTest(flag=flag), self.get_mount() as path: + self.assertTrue(path.is_mount()) + unmount(path, flag) + self.assertFalse(path.is_mount()) + + def test_unmount_flags_expire(self): + """Test unmount MNT_EXPIRE behavior.""" + with self.get_mount() as path: + with self.assertRaises(BlockingIOError): + unmount(path, UnmountFlags.MNT_EXPIRE) + + def test_unmount_errors(self): + """Test invalid unmount errors.""" + cases = [ + (dict(target="not/exist"), OSError, "is not a mount point"), + (dict(target=Path("not/exist")), OSError, "is not a mount point"), + ] + for case, err, msg in cases: + with self.subTest(case=case): + with self.assertRaises(err) as cm: + unmount(**case) + self.assertIn(msg, str(cm.exception)) + + def test_unmount_invalid_args(self): + """Test invalid unmount invalid flag.""" + with self.get_mount() as path: + with self.assertRaises(OSError) as cm: + unmount(path, 251) + self.assertIn("Invalid argument", str(cm.exception)) + + def test_threading(self): + """Test concurrent mounting works in multi-thread environments.""" + paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)] + + for path in paths: + path.mkdir() + self.assertFalse(path.is_mount()) + + try: + with ThreadPoolExecutor() as pool: + res = list( + pool.map( + mount, + [""] * len(paths), + paths, + ["tmpfs"] * len(paths), + ) + ) + self.assertEqual(len(res), len(paths)) + + for path in paths: + with self.subTest(path=path): + self.assertTrue(path.is_mount()) + + unmounts = list(pool.map(unmount, paths)) + self.assertEqual(len(unmounts), len(paths)) + + for path in paths: + with self.subTest(path=path): + self.assertFalse(path.is_mount()) + finally: + with suppress(OSError): + for path in paths: + unmount(path) diff --git a/tests/test_integration.py b/tests/test_integration.py index 7c5db2b..91b01e6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,14 +1,25 @@ import json import unittest import urllib.request +from base64 import b64encode from multiprocessing.dummy import Pool +from textwrap import dedent from tests.gunicorn_utils import run_gunicorn -def run_code_in_snekbox(code: str) -> tuple[str, int]: - body = {"input": code} - json_data = json.dumps(body).encode("utf-8") +def b64encode_code(data: str): + data = dedent(data).strip() + return b64encode(data.encode()).decode("ascii") + + +def snekbox_run_code(code: str) -> tuple[str, int]: + body = {"args": ["-c", code]} + return snekbox_request(body) + + +def snekbox_request(content: dict) -> tuple[str, int]: + json_data = json.dumps(content).encode("utf-8") req = urllib.request.Request("http://localhost:8060/eval") req.add_header("Content-Type", "application/json; charset=utf-8") @@ -34,9 +45,71 @@ class IntegrationTests(unittest.TestCase): args = [code] * processes with Pool(processes) as p: - results = p.map(run_code_in_snekbox, args) + results = p.map(snekbox_run_code, args) responses, statuses = zip(*results) self.assertTrue(all(status == 200 for status in statuses)) self.assertTrue(all(json.loads(response)["returncode"] == 0 for response in responses)) + + def test_eval(self): + """Test normal eval requests without files.""" + with run_gunicorn(): + cases = [ + ({"input": "print('Hello')"}, "Hello\n"), + ({"args": ["-c", "print('abc12')"]}, "abc12\n"), + ] + for body, expected in cases: + with self.subTest(body=body): + response, status = snekbox_request(body) + self.assertEqual(status, 200) + self.assertEqual(json.loads(response)["stdout"], expected) + + def test_files_send_receive(self): + """Test sending and receiving files to snekbox.""" + with run_gunicorn(): + request = { + "args": ["main.py"], + "files": [ + { + "path": "main.py", + "content": b64encode_code( + """ + from pathlib import Path + from mod import lib + print(lib.var) + + with open('test.txt', 'w') as f: + f.write('test 1') + + Path('dir').mkdir() + Path('dir/test2.txt').write_text('test 2') + """ + ), + }, + {"path": "mod/__init__.py"}, + {"path": "mod/lib.py", "content": b64encode_code("var = 'hello'")}, + ], + } + + expected = { + "stdout": "hello\n", + "returncode": 0, + "files": [ + { + "path": "dir/test2.txt", + "size": len("test 2"), + "content": b64encode_code("test 2"), + }, + { + "path": "test.txt", + "size": len("test 1"), + "content": b64encode_code("test 1"), + }, + ], + } + + response, status = snekbox_request(request) + + self.assertEqual(200, status) + self.assertEqual(expected, json.loads(response)) diff --git a/tests/test_main.py b/tests/test_main.py index 77b3130..24c067c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase): @patch("sys.argv", ["", "import sys; sys.exit(22)"]) def test_main_exits_with_returncode(self): - """Should exit with the subprocess's returncode if it's non-zero.""" + """Should exit with the subprocess returncode if it's non-zero.""" with self.assertRaises(SystemExit) as cm: snekbox_main.main() diff --git a/tests/test_memfs.py b/tests/test_memfs.py new file mode 100644 index 0000000..0555726 --- /dev/null +++ b/tests/test_memfs.py @@ -0,0 +1,65 @@ +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import ExitStack +from unittest import TestCase, mock +from uuid import uuid4 + +from snekbox.memfs import MemFS + +UUID_TEST = uuid4() + + +class MemFSTests(TestCase): + def setUp(self): + super().setUp() + self.logger = logging.getLogger("snekbox.memfs") + self.logger.setLevel(logging.WARNING) + + @mock.patch("snekbox.memfs.uuid4", lambda: UUID_TEST) + def test_assignment_thread_safe(self): + """Test concurrent mounting works in multi-thread environments.""" + # Concurrently create MemFS in threads, check only 1 can be created + # Others should result in RuntimeError + with ExitStack() as stack: + with ThreadPoolExecutor() as executor: + memfs: MemFS | None = None + # Each future uses enter_context to ensure __exit__ on test exception + futures = [ + executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8) + ] + for future in futures: + # We should have exactly one result and all others RuntimeErrors + if err := future.exception(): + self.assertIsInstance(err, RuntimeError) + else: + self.assertIsNone(memfs) + memfs = future.result() + + # Original memfs should still exist afterwards + self.assertIsInstance(memfs, MemFS) + self.assertTrue(memfs.path.is_mount()) + + def test_cleanup(self): + """Test explicit cleanup.""" + memfs = MemFS(10) + path = memfs.path + self.assertTrue(path.is_mount()) + memfs.cleanup() + self.assertFalse(path.exists()) + + def test_context_cleanup(self): + """Context __exit__ should trigger cleanup.""" + with MemFS(10) as memfs: + path = memfs.path + self.assertTrue(path.is_mount()) + self.assertFalse(path.exists()) + + def test_implicit_cleanup(self): + """Test implicit _cleanup triggered by GC.""" + memfs = MemFS(10) + path = memfs.path + self.assertTrue(path.is_mount()) + # Catch the warning about implicit cleanup + with self.assertWarns(ResourceWarning): + del memfs + self.assertFalse(path.exists()) diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 6f6e2a7..cad79f3 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -9,28 +9,40 @@ from itertools import product from pathlib import Path from textwrap import dedent +from snekbox.filesystem import Size from snekbox.nsjail import NsJail +from snekbox.snekio import FileAttachment class NsJailTests(unittest.TestCase): def setUp(self): super().setUp() - self.nsjail = NsJail() + # Specify lower limits for unit tests to complete within time limits + self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB) self.logger = logging.getLogger("snekbox.nsjail") self.logger.setLevel(logging.WARNING) + def eval_code(self, code: str): + return self.nsjail.python3(["-c", code]) + + def eval_file(self, code: str, name: str = "test.py", **kwargs): + file = FileAttachment(name, code.encode()) + return self.nsjail.python3([name], [file], **kwargs) + def test_print_returns_0(self): - result = self.nsjail.python3("print('test')") - self.assertEqual(result.returncode, 0) - self.assertEqual(result.stdout, "test\n") - self.assertEqual(result.stderr, None) + for fn in (self.eval_code, self.eval_file): + with self.subTest(fn.__name__): + result = fn("print('test')") + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "test\n") + self.assertEqual(result.stderr, None) def test_timeout_returns_137(self): code = "while True: pass" with self.assertLogs(self.logger) as log: - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 137) self.assertEqual(result.stdout, "") @@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase): # Add a kilobyte just to be safe. code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}" - result = self.nsjail.python3(code) - self.assertEqual(result.returncode, 137) + result = self.eval_file(code) self.assertEqual(result.stdout, "") + self.assertEqual(result.returncode, 137) + self.assertEqual(result.stderr, None) + + def test_multi_files(self): + files = [ + FileAttachment("main.py", "import lib; print(lib.x)".encode()), + FileAttachment("lib.py", "x = 'hello'".encode()), + ] + + result = self.nsjail.python3(["main.py"], files) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "hello\n") self.assertEqual(result.stderr, None) def test_subprocess_resource_unavailable(self): + max_pids = self.nsjail.config.cgroup_pids_max code = dedent( - """ + f""" import subprocess - # Max PIDs is 5. - for _ in range(6): + # Should fail at n (max PIDs) since the caller python process counts as well + for _ in range({max_pids}): print(subprocess.Popen( [ '/usr/local/bin/python3', @@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) + # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n" + expected = "\n".join(map(str, range(2, max_pids + 1))) + self.assertIn(expected, result.stdout) self.assertEqual(result.stderr, None) def test_multiprocess_resource_limits(self): @@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase): """ ) - result = self.nsjail.python3(code) + result = self.eval_file(code) exit_codes = result.stdout.strip().split() self.assertIn("-9", exit_codes) @@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Read-only file system", result.stdout) self.assertEqual(result.stderr, None) + def test_write(self): + code = dedent( + """ + from pathlib import Path + with open('test.txt', 'w') as f: + f.write('hello') + print(Path('test.txt').read_text()) + """ + ).strip() + + result = self.eval_file(code) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "hello\n") + self.assertEqual(result.stderr, None) + + def test_write_exceed_space(self): + code = dedent( + f""" + size = {self.nsjail.memfs_instance_size} // 2048 + with open('f.bin', 'wb') as f: + for i in range(size): + f.write(b'1' * 2048) + """ + ).strip() + + result = self.eval_file(code) + self.assertEqual(result.returncode, 1) + self.assertIn("No space left on device", result.stdout) + self.assertEqual(result.stderr, None) + def test_forkbomb_resource_unavailable(self): code = dedent( """ @@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) self.assertEqual(result.stderr, None) + def test_file_parsing_timeout(self): + code = dedent( + """ + import os + data = "a" * 1024 + size = 32 * 1024 * 1024 + + with open("file", "w") as f: + for _ in range((size // 1024) - 5): + f.write(data) + + for i in range(100): + os.symlink("file", f"file{i}") + """ + ).strip() + + nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1) + result = nsjail.python3(["-c", code]) + self.assertEqual(result.returncode, None) + self.assertEqual( + result.stdout, "TimeoutError: Exceeded time limit while parsing attachments" + ) + self.assertEqual(result.stderr, None) + + def test_file_write_error(self): + """Test errors during file write.""" + result = self.nsjail.python3( + [""], + [ + FileAttachment("dir/test.txt", b"abc"), + FileAttachment("dir", b"xyz"), + ], + ) + + self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.") + self.assertEqual(result.stderr, None) + self.assertEqual(result.returncode, None) + def test_sigsegv_returns_139(self): # In honour of Juan. code = dedent( """ @@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 139) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) def test_null_byte_value_error(self): - result = self.nsjail.python3("\0") + # This error only occurs with `-c` mode + result = self.eval_code("\0") self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, "ValueError: embedded null byte") self.assertEqual(result.stderr, None) def test_print_bad_unicode_encode_error(self): - result = self.nsjail.python3("print(chr(56550))") + result = self.eval_file("print(chr(56550))") self.assertEqual(result.returncode, 1) self.assertIn("UnicodeEncodeError", result.stdout) self.assertEqual(result.stderr, None) def test_unicode_env_erase_escape_fails(self): - result = self.nsjail.python3( + result = self.eval_file( dedent( """ import os @@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No such file or directory", result.stdout) self.assertEqual(result.stderr, None) @@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Function not implemented", result.stdout) self.assertEqual(result.stderr, None) def test_numpy_import(self): - result = self.nsjail.python3("import numpy") + result = self.eval_file("import numpy") self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertLess( result.stdout.find(stdout_msg), result.stdout.find(stderr_msg), @@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase): def test_stdout_flood_results_in_graceful_sigterm(self): code = "while True: print('abcdefghij')" - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 143) def test_large_output_is_truncated(self): @@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase): self.assertEqual(output, chunk * expected_chunks) def test_nsjail_args(self): - args = ("foo", "bar") - result = self.nsjail.python3("", nsjail_args=args) + args = ["foo", "bar"] + result = self.nsjail.python3((), nsjail_args=args) end = result.args.index("--") self.assertEqual(result.args[end - len(args) : end], args) def test_py_args(self): - args = ("-m", "timeit") - result = self.nsjail.python3("", py_args=args) - - self.assertEqual(result.returncode, 0) - self.assertEqual(result.args[-3:-1], args) + cases = [ + # Normal args + (["-c", "print('hello')"], ["-c", "print('hello')"]), + # Leading empty strings should be removed + (["", "-m", "timeit"], ["-m", "timeit"]), + (["", "", "-m", "timeit"], ["-m", "timeit"]), + (["", "", "", "-m", "timeit"], ["-m", "timeit"]), + # Non-leading empty strings should be preserved + (["-m", "timeit", ""], ["-m", "timeit", ""]), + ] + + for args, expected in cases: + with self.subTest(args=args): + result = self.nsjail.python3(py_args=args) + idx = result.args.index("-BSqu") + self.assertEqual(result.args[idx + 1 :], expected) + self.assertEqual(result.returncode, 0) class NsJailArgsTests(unittest.TestCase): diff --git a/tests/test_snekio.py b/tests/test_snekio.py new file mode 100644 index 0000000..8f04429 --- /dev/null +++ b/tests/test_snekio.py @@ -0,0 +1,58 @@ +from unittest import TestCase + +from snekbox import snekio +from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError + + +class SnekIOTests(TestCase): + def test_safe_path(self) -> None: + cases = [ + ("", ""), + ("foo", "foo"), + ("foo/bar", "foo/bar"), + ("foo/bar.ext", "foo/bar.ext"), + ] + + for path, expected in cases: + self.assertEqual(snekio.safe_path(path), expected) + + def test_safe_path_raise(self): + cases = [ + ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"), + ("/foo", IllegalPathError, "File path '/foo' must be relative"), + ] + + for path, error, msg in cases: + with self.assertRaises(error) as cm: + snekio.safe_path(path) + self.assertEqual(str(cm.exception), msg) + + def test_file_from_dict(self): + cases = [ + ({"path": "foo", "content": ""}, FileAttachment("foo", b"")), + ({"path": "foo"}, FileAttachment("foo", b"")), + ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")), + ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")), + ] + + for data, expected in cases: + self.assertEqual(FileAttachment.from_dict(data), expected) + + def test_file_from_dict_error(self): + cases = [ + ( + {"path": "foo", "content": "9"}, + ParsingError, + "Invalid base64 encoding for file 'foo'", + ), + ( + {"path": "var/a.txt", "content": "1="}, + ParsingError, + "Invalid base64 encoding for file 'var/a.txt'", + ), + ] + + for data, error, msg in cases: + with self.assertRaises(error) as cm: + FileAttachment.from_dict(data) + self.assertEqual(str(cm.exception), msg) |