aboutsummaryrefslogtreecommitdiffstats
path: root/tests/test_nsjail.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nsjail.py')
-rw-r--r--tests/test_nsjail.py170
1 files changed, 139 insertions, 31 deletions
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index 6f6e2a7..cad79f3 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -9,28 +9,40 @@ from itertools import product
from pathlib import Path
from textwrap import dedent
+from snekbox.filesystem import Size
from snekbox.nsjail import NsJail
+from snekbox.snekio import FileAttachment
class NsJailTests(unittest.TestCase):
def setUp(self):
super().setUp()
- self.nsjail = NsJail()
+ # Specify lower limits for unit tests to complete within time limits
+ self.nsjail = NsJail(memfs_instance_size=2 * Size.MiB)
self.logger = logging.getLogger("snekbox.nsjail")
self.logger.setLevel(logging.WARNING)
+ def eval_code(self, code: str):
+ return self.nsjail.python3(["-c", code])
+
+ def eval_file(self, code: str, name: str = "test.py", **kwargs):
+ file = FileAttachment(name, code.encode())
+ return self.nsjail.python3([name], [file], **kwargs)
+
def test_print_returns_0(self):
- result = self.nsjail.python3("print('test')")
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.stdout, "test\n")
- self.assertEqual(result.stderr, None)
+ for fn in (self.eval_code, self.eval_file):
+ with self.subTest(fn.__name__):
+ result = fn("print('test')")
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "test\n")
+ self.assertEqual(result.stderr, None)
def test_timeout_returns_137(self):
code = "while True: pass"
with self.assertLogs(self.logger) as log:
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 137)
self.assertEqual(result.stdout, "")
@@ -41,18 +53,30 @@ class NsJailTests(unittest.TestCase):
# Add a kilobyte just to be safe.
code = f"x = ' ' * {self.nsjail.config.cgroup_mem_max + 1000}"
- result = self.nsjail.python3(code)
- self.assertEqual(result.returncode, 137)
+ result = self.eval_file(code)
self.assertEqual(result.stdout, "")
+ self.assertEqual(result.returncode, 137)
+ self.assertEqual(result.stderr, None)
+
+ def test_multi_files(self):
+ files = [
+ FileAttachment("main.py", "import lib; print(lib.x)".encode()),
+ FileAttachment("lib.py", "x = 'hello'".encode()),
+ ]
+
+ result = self.nsjail.python3(["main.py"], files)
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "hello\n")
self.assertEqual(result.stderr, None)
def test_subprocess_resource_unavailable(self):
+ max_pids = self.nsjail.config.cgroup_pids_max
code = dedent(
- """
+ f"""
import subprocess
- # Max PIDs is 5.
- for _ in range(6):
+ # Should fail at n (max PIDs) since the caller python process counts as well
+ for _ in range({max_pids}):
print(subprocess.Popen(
[
'/usr/local/bin/python3',
@@ -63,9 +87,12 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
+ # Expect n-1 processes to be opened by the presence of string like "2\n3\n4\n"
+ expected = "\n".join(map(str, range(2, max_pids + 1)))
+ self.assertIn(expected, result.stdout)
self.assertEqual(result.stderr, None)
def test_multiprocess_resource_limits(self):
@@ -92,7 +119,7 @@ class NsJailTests(unittest.TestCase):
"""
)
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
exit_codes = result.stdout.strip().split()
self.assertIn("-9", exit_codes)
@@ -108,11 +135,41 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Read-only file system", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_write(self):
+ code = dedent(
+ """
+ from pathlib import Path
+ with open('test.txt', 'w') as f:
+ f.write('hello')
+ print(Path('test.txt').read_text())
+ """
+ ).strip()
+
+ result = self.eval_file(code)
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.stdout, "hello\n")
+ self.assertEqual(result.stderr, None)
+
+ def test_write_exceed_space(self):
+ code = dedent(
+ f"""
+ size = {self.nsjail.memfs_instance_size} // 2048
+ with open('f.bin', 'wb') as f:
+ for i in range(size):
+ f.write(b'1' * 2048)
+ """
+ ).strip()
+
+ result = self.eval_file(code)
+ self.assertEqual(result.returncode, 1)
+ self.assertIn("No space left on device", result.stdout)
+ self.assertEqual(result.stderr, None)
+
def test_forkbomb_resource_unavailable(self):
code = dedent(
"""
@@ -122,11 +179,49 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Resource temporarily unavailable", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_file_parsing_timeout(self):
+ code = dedent(
+ """
+ import os
+ data = "a" * 1024
+ size = 32 * 1024 * 1024
+
+ with open("file", "w") as f:
+ for _ in range((size // 1024) - 5):
+ f.write(data)
+
+ for i in range(100):
+ os.symlink("file", f"file{i}")
+ """
+ ).strip()
+
+ nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, None)
+ self.assertEqual(
+ result.stdout, "TimeoutError: Exceeded time limit while parsing attachments"
+ )
+ self.assertEqual(result.stderr, None)
+
+ def test_file_write_error(self):
+ """Test errors during file write."""
+ result = self.nsjail.python3(
+ [""],
+ [
+ FileAttachment("dir/test.txt", b"abc"),
+ FileAttachment("dir", b"xyz"),
+ ],
+ )
+
+ self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'dir'.")
+ self.assertEqual(result.stderr, None)
+ self.assertEqual(result.returncode, None)
+
def test_sigsegv_returns_139(self): # In honour of Juan.
code = dedent(
"""
@@ -135,25 +230,26 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 139)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
def test_null_byte_value_error(self):
- result = self.nsjail.python3("\0")
+ # This error only occurs with `-c` mode
+ result = self.eval_code("\0")
self.assertEqual(result.returncode, None)
self.assertEqual(result.stdout, "ValueError: embedded null byte")
self.assertEqual(result.stderr, None)
def test_print_bad_unicode_encode_error(self):
- result = self.nsjail.python3("print(chr(56550))")
+ result = self.eval_file("print(chr(56550))")
self.assertEqual(result.returncode, 1)
self.assertIn("UnicodeEncodeError", result.stdout)
self.assertEqual(result.stderr, None)
def test_unicode_env_erase_escape_fails(self):
- result = self.nsjail.python3(
+ result = self.eval_file(
dedent(
"""
import os
@@ -207,7 +303,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("No such file or directory", result.stdout)
self.assertEqual(result.stderr, None)
@@ -223,13 +319,13 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 1)
self.assertIn("Function not implemented", result.stdout)
self.assertEqual(result.stderr, None)
def test_numpy_import(self):
- result = self.nsjail.python3("import numpy")
+ result = self.eval_file("import numpy")
self.assertEqual(result.returncode, 0)
self.assertEqual(result.stdout, "")
self.assertEqual(result.stderr, None)
@@ -244,7 +340,7 @@ class NsJailTests(unittest.TestCase):
"""
).strip()
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertLess(
result.stdout.find(stdout_msg),
result.stdout.find(stderr_msg),
@@ -255,7 +351,7 @@ class NsJailTests(unittest.TestCase):
def test_stdout_flood_results_in_graceful_sigterm(self):
code = "while True: print('abcdefghij')"
- result = self.nsjail.python3(code)
+ result = self.eval_file(code)
self.assertEqual(result.returncode, 143)
def test_large_output_is_truncated(self):
@@ -272,18 +368,30 @@ class NsJailTests(unittest.TestCase):
self.assertEqual(output, chunk * expected_chunks)
def test_nsjail_args(self):
- args = ("foo", "bar")
- result = self.nsjail.python3("", nsjail_args=args)
+ args = ["foo", "bar"]
+ result = self.nsjail.python3((), nsjail_args=args)
end = result.args.index("--")
self.assertEqual(result.args[end - len(args) : end], args)
def test_py_args(self):
- args = ("-m", "timeit")
- result = self.nsjail.python3("", py_args=args)
-
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.args[-3:-1], args)
+ cases = [
+ # Normal args
+ (["-c", "print('hello')"], ["-c", "print('hello')"]),
+ # Leading empty strings should be removed
+ (["", "-m", "timeit"], ["-m", "timeit"]),
+ (["", "", "-m", "timeit"], ["-m", "timeit"]),
+ (["", "", "", "-m", "timeit"], ["-m", "timeit"]),
+ # Non-leading empty strings should be preserved
+ (["-m", "timeit", ""], ["-m", "timeit", ""]),
+ ]
+
+ for args, expected in cases:
+ with self.subTest(args=args):
+ result = self.nsjail.python3(py_args=args)
+ idx = result.args.index("-BSqu")
+ self.assertEqual(result.args[idx + 1 :], expected)
+ self.assertEqual(result.returncode, 0)
class NsJailArgsTests(unittest.TestCase):