diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_eval.py | 17 | ||||
-rw-r--r-- | tests/test_integration.py | 2 | ||||
-rw-r--r-- | tests/test_main.py | 6 | ||||
-rw-r--r-- | tests/test_nsjail.py | 60 |
4 files changed, 47 insertions, 38 deletions
diff --git a/tests/api/test_eval.py b/tests/api/test_eval.py index 976970e..caa848e 100644 --- a/tests/api/test_eval.py +++ b/tests/api/test_eval.py @@ -5,7 +5,7 @@ class TestEvalResource(SnekAPITestCase): PATH = "/eval" def test_post_valid_200(self): - body = {"input": "foo"} + body = {"args": ["-c", "print('output')"]} result = self.simulate_post(self.PATH, json=body) self.assertEqual(result.status_code, 200) @@ -20,26 +20,25 @@ class TestEvalResource(SnekAPITestCase): expected = { "title": "Request data failed validation", - "description": "'input' is a required property", + "description": "'args' is a required property", } self.assertEqual(expected, result.json) def test_post_invalid_data_400(self): - bodies = ({"input": 400}, {"input": "", "args": [400]}) - - for body in bodies: + bodies = ({"args": 400}, {"args": [], "files": [215]}) + expects = ["400 is not of type 'array'", "215 is not of type 'object'"] + for body, expected in zip(bodies, expects): with self.subTest(): result = self.simulate_post(self.PATH, json=body) self.assertEqual(result.status_code, 400) - expected = { + expected_json = { "title": "Request data failed validation", - "description": "400 is not of type 'string'", + "description": expected, } - - self.assertEqual(expected, result.json) + self.assertEqual(expected_json, result.json) def test_post_invalid_content_type_415(self): body = "{'input': 'foo'}" diff --git a/tests/test_integration.py b/tests/test_integration.py index 7c5db2b..eba5e60 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,7 +7,7 @@ from tests.gunicorn_utils import run_gunicorn def run_code_in_snekbox(code: str) -> tuple[str, int]: - body = {"input": code} + body = {"args": ["-c", code]} json_data = json.dumps(body).encode("utf-8") req = urllib.request.Request("http://localhost:8060/eval") diff --git a/tests/test_main.py b/tests/test_main.py index 1e6cbc5..24c067c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -12,10 +12,10 @@ import snekbox.__main__ as snekbox_main class ArgParseTests(unittest.TestCase): def test_parse_args(self): subtests = ( - (["", "code"], Namespace(code="code", nsjail_args=[], py_args=[])), + (["", "code"], Namespace(code="code", nsjail_args=[], py_args=["-c"])), ( ["", "code", "--time_limit", "0"], - Namespace(code="code", nsjail_args=["--time_limit", "0"], py_args=[]), + Namespace(code="code", nsjail_args=["--time_limit", "0"], py_args=["-c"]), ), ( ["", "code", "---", "-m", "timeit"], @@ -63,7 +63,7 @@ class EntrypointTests(unittest.TestCase): @patch("sys.argv", ["", "import sys; sys.exit(22)"]) def test_main_exits_with_returncode(self): - """Should exit with the subprocess's returncode if it's non-zero.""" + """Should exit with the subprocess returncode if it's non-zero.""" with self.assertRaises(SystemExit) as cm: snekbox_main.main() diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index cea96bd..324b88a 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -10,6 +10,7 @@ from pathlib import Path from textwrap import dedent from snekbox.nsjail import NsJail +from snekbox.snekio import EvalRequestFile class NsJailTests(unittest.TestCase): @@ -20,17 +21,26 @@ class NsJailTests(unittest.TestCase): self.logger = logging.getLogger("snekbox.nsjail") self.logger.setLevel(logging.WARNING) + def eval_code(self, code: str): + return self.nsjail.python3(["-c", code]) + + def eval_file(self, code: str, name: str = "test.py"): + file = EvalRequestFile(name, code) + return self.nsjail.python3([name], [file]) + def test_print_returns_0(self): - result = self.nsjail.python3("print('test')") - self.assertEqual(result.returncode, 0) - self.assertEqual(result.stdout, "test\n") - self.assertEqual(result.stderr, None) + for fn in (self.eval_code, self.eval_file): + with self.subTest(fn.__name__): + result = fn("print('test')") + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "test\n") + self.assertEqual(result.stderr, None) def test_timeout_returns_137(self): code = "while True: pass" with self.assertLogs(self.logger) as log: - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 137) self.assertEqual(result.stdout, "") @@ -41,7 +51,7 @@ class NsJailTests(unittest.TestCase): # Add a kilobyte just to be safe. code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}" - result = self.nsjail.python3(code, py_args=("-c",)) + result = self.eval_file(code) self.assertEqual(result.stdout, "") self.assertEqual(result.returncode, 137) self.assertEqual(result.stderr, None) @@ -64,7 +74,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) # Also expect n-1 processes to be opened @@ -96,7 +106,7 @@ class NsJailTests(unittest.TestCase): """ ) - result = self.nsjail.python3(code) + result = self.eval_file(code) exit_codes = result.stdout.strip().split() self.assertIn("-9", exit_codes) @@ -112,7 +122,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Read-only file system", result.stdout) self.assertEqual(result.stderr, None) @@ -127,7 +137,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "hello\n") self.assertEqual(result.stderr, None) @@ -145,7 +155,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No space left on device", result.stdout) self.assertEqual(result.stderr, None) @@ -165,7 +175,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 0) self.assertEqual( result.stdout, @@ -184,7 +194,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) self.assertEqual(result.stderr, None) @@ -197,7 +207,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 139) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -205,19 +215,19 @@ class NsJailTests(unittest.TestCase): def test_null_byte_value_error(self): # This error does not occur without -c, where it # would be a normal SyntaxError. - result = self.nsjail.python3("\0", py_args=("-c",)) + result = self.nsjail.python3(["-c", "\0"]) self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, "ValueError: embedded null byte") self.assertEqual(result.stderr, None) def test_print_bad_unicode_encode_error(self): - result = self.nsjail.python3("print(chr(56550))") + result = self.eval_file("print(chr(56550))") self.assertEqual(result.returncode, 1) self.assertIn("UnicodeEncodeError", result.stdout) self.assertEqual(result.stderr, None) def test_unicode_env_erase_escape_fails(self): - result = self.nsjail.python3( + result = self.eval_file( dedent( """ import os @@ -271,7 +281,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No such file or directory", result.stdout) self.assertEqual(result.stderr, None) @@ -287,13 +297,13 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Function not implemented", result.stdout) self.assertEqual(result.stderr, None) def test_numpy_import(self): - result = self.nsjail.python3("import numpy") + result = self.eval_file("import numpy") self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -308,7 +318,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertLess( result.stdout.find(stdout_msg), result.stdout.find(stderr_msg), @@ -319,7 +329,7 @@ class NsJailTests(unittest.TestCase): def test_stdout_flood_results_in_graceful_sigterm(self): code = "while True: print('abcdefghij')" - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 143) def test_large_output_is_truncated(self): @@ -337,17 +347,17 @@ class NsJailTests(unittest.TestCase): def test_nsjail_args(self): args = ["foo", "bar"] - result = self.nsjail.python3("", nsjail_args=args) + result = self.nsjail.python3((), nsjail_args=args) end = result.args.index("--") self.assertEqual(result.args[end - len(args) : end], args) def test_py_args(self): args = ["-m", "timeit"] - result = self.nsjail.python3("", py_args=args) + result = self.nsjail.python3(args) self.assertEqual(result.returncode, 0) - self.assertEqual(result.args[-3:-1], args) + self.assertEqual(result.args[-2:], args) class NsJailArgsTests(unittest.TestCase): |