diff options
author | 2023-03-10 02:03:01 +0200 | |
---|---|---|
committer | 2023-03-10 02:03:01 +0200 | |
commit | 8a85b86174067618891e2530dd56b5025fd2f28b (patch) | |
tree | cc85ad898f492c0b799917cda05537c88bea83c6 /tests/test_nsjail.py | |
parent | Merge pull request #167 from python-discord/deployment-update (diff) | |
parent | Merge branch 'main' into bytes-output (diff) |
Merge pull request #159 from python-discord/bytes-output
File system and Binary file sending
Diffstat (limited to 'tests/test_nsjail.py')
-rw-r--r-- | tests/test_nsjail.py | 170 |
1 files changed, 139 insertions, 31 deletions
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 6f6e2a7..cad79f3 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -9,28 +9,40 @@ from itertools import product from pathlib import Path from textwrap import dedent +from snekbox.filesystem import Size from snekbox.nsjail import NsJail +from snekbox.snekio import FileAttachment class NsJailTests(unittest.TestCase): def setUp(self): super().setUp() - self.nsjail = NsJail() + # Specify lower limits for unit tests to complete within time limits + self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB) self.logger = logging.getLogger("snekbox.nsjail") self.logger.setLevel(logging.WARNING) + def eval_code(self, code: str): + return self.nsjail.python3(["-c", code]) + + def eval_file(self, code: str, name: str = "test.py", **kwargs): + file = FileAttachment(name, code.encode()) + return self.nsjail.python3([name], [file], **kwargs) + def test_print_returns_0(self): - result = self.nsjail.python3("print('test')") - self.assertEqual(result.returncode, 0) - self.assertEqual(result.stdout, "test\n") - self.assertEqual(result.stderr, None) + for fn in (self.eval_code, self.eval_file): + with self.subTest(fn.__name__): + result = fn("print('test')") + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "test\n") + self.assertEqual(result.stderr, None) def test_timeout_returns_137(self): code = "while True: pass" with self.assertLogs(self.logger) as log: - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 137) self.assertEqual(result.stdout, "") @@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase): # Add a kilobyte just to be safe. code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}" - result = self.nsjail.python3(code) - self.assertEqual(result.returncode, 137) + result = self.eval_file(code) self.assertEqual(result.stdout, "") + self.assertEqual(result.returncode, 137) + self.assertEqual(result.stderr, None) + + def test_multi_files(self): + files = [ + FileAttachment("main.py", "import lib; print(lib.x)".encode()), + FileAttachment("lib.py", "x = 'hello'".encode()), + ] + + result = self.nsjail.python3(["main.py"], files) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "hello\n") self.assertEqual(result.stderr, None) def test_subprocess_resource_unavailable(self): + max_pids = self.nsjail.config.cgroup_pids_max code = dedent( - """ + f""" import subprocess - # Max PIDs is 5. - for _ in range(6): + # Should fail at n (max PIDs) since the caller python process counts as well + for _ in range({max_pids}): print(subprocess.Popen( [ '/usr/local/bin/python3', @@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) + # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n" + expected = "\n".join(map(str, range(2, max_pids + 1))) + self.assertIn(expected, result.stdout) self.assertEqual(result.stderr, None) def test_multiprocess_resource_limits(self): @@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase): """ ) - result = self.nsjail.python3(code) + result = self.eval_file(code) exit_codes = result.stdout.strip().split() self.assertIn("-9", exit_codes) @@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Read-only file system", result.stdout) self.assertEqual(result.stderr, None) + def test_write(self): + code = dedent( + """ + from pathlib import Path + with open('test.txt', 'w') as f: + f.write('hello') + print(Path('test.txt').read_text()) + """ + ).strip() + + result = self.eval_file(code) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.stdout, "hello\n") + self.assertEqual(result.stderr, None) + + def test_write_exceed_space(self): + code = dedent( + f""" + size = {self.nsjail.memfs_instance_size} // 2048 + with open('f.bin', 'wb') as f: + for i in range(size): + f.write(b'1' * 2048) + """ + ).strip() + + result = self.eval_file(code) + self.assertEqual(result.returncode, 1) + self.assertIn("No space left on device", result.stdout) + self.assertEqual(result.stderr, None) + def test_forkbomb_resource_unavailable(self): code = dedent( """ @@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Resource temporarily unavailable", result.stdout) self.assertEqual(result.stderr, None) + def test_file_parsing_timeout(self): + code = dedent( + """ + import os + data = "a" * 1024 + size = 32 * 1024 * 1024 + + with open("file", "w") as f: + for _ in range((size // 1024) - 5): + f.write(data) + + for i in range(100): + os.symlink("file", f"file{i}") + """ + ).strip() + + nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1) + result = nsjail.python3(["-c", code]) + self.assertEqual(result.returncode, None) + self.assertEqual( + result.stdout, "TimeoutError: Exceeded time limit while parsing attachments" + ) + self.assertEqual(result.stderr, None) + + def test_file_write_error(self): + """Test errors during file write.""" + result = self.nsjail.python3( + [""], + [ + FileAttachment("dir/test.txt", b"abc"), + FileAttachment("dir", b"xyz"), + ], + ) + + self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.") + self.assertEqual(result.stderr, None) + self.assertEqual(result.returncode, None) + def test_sigsegv_returns_139(self): # In honour of Juan. code = dedent( """ @@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 139) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) def test_null_byte_value_error(self): - result = self.nsjail.python3("\0") + # This error only occurs with `-c` mode + result = self.eval_code("\0") self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, "ValueError: embedded null byte") self.assertEqual(result.stderr, None) def test_print_bad_unicode_encode_error(self): - result = self.nsjail.python3("print(chr(56550))") + result = self.eval_file("print(chr(56550))") self.assertEqual(result.returncode, 1) self.assertIn("UnicodeEncodeError", result.stdout) self.assertEqual(result.stderr, None) def test_unicode_env_erase_escape_fails(self): - result = self.nsjail.python3( + result = self.eval_file( dedent( """ import os @@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("No such file or directory", result.stdout) self.assertEqual(result.stderr, None) @@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 1) self.assertIn("Function not implemented", result.stdout) self.assertEqual(result.stderr, None) def test_numpy_import(self): - result = self.nsjail.python3("import numpy") + result = self.eval_file("import numpy") self.assertEqual(result.returncode, 0) self.assertEqual(result.stdout, "") self.assertEqual(result.stderr, None) @@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase): """ ).strip() - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertLess( result.stdout.find(stdout_msg), result.stdout.find(stderr_msg), @@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase): def test_stdout_flood_results_in_graceful_sigterm(self): code = "while True: print('abcdefghij')" - result = self.nsjail.python3(code) + result = self.eval_file(code) self.assertEqual(result.returncode, 143) def test_large_output_is_truncated(self): @@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase): self.assertEqual(output, chunk * expected_chunks) def test_nsjail_args(self): - args = ("foo", "bar") - result = self.nsjail.python3("", nsjail_args=args) + args = ["foo", "bar"] + result = self.nsjail.python3((), nsjail_args=args) end = result.args.index("--") self.assertEqual(result.args[end - len(args) : end], args) def test_py_args(self): - args = ("-m", "timeit") - result = self.nsjail.python3("", py_args=args) - - self.assertEqual(result.returncode, 0) - self.assertEqual(result.args[-3:-1], args) + cases = [ + # Normal args + (["-c", "print('hello')"], ["-c", "print('hello')"]), + # Leading empty strings should be removed + (["", "-m", "timeit"], ["-m", "timeit"]), + (["", "", "-m", "timeit"], ["-m", "timeit"]), + (["", "", "", "-m", "timeit"], ["-m", "timeit"]), + # Non-leading empty strings should be preserved + (["-m", "timeit", ""], ["-m", "timeit", ""]), + ] + + for args, expected in cases: + with self.subTest(args=args): + result = self.nsjail.python3(py_args=args) + idx = result.args.index("-BSqu") + self.assertEqual(result.args[idx + 1 :], expected) + self.assertEqual(result.returncode, 0) class NsJailArgsTests(unittest.TestCase): |