diff options
-rw-r--r-- | scripts/.profile | 9 | ||||
-rw-r--r-- | snekbox/nsjail.py | 14 | ||||
-rw-r--r-- | tests/test_nsjail.py | 19 |
3 files changed, 34 insertions, 8 deletions
diff --git a/scripts/.profile b/scripts/.profile index 415e4f6..bff260d 100644 --- a/scripts/.profile +++ b/scripts/.profile @@ -1,12 +1,19 @@ nsjpy() { + local MEM_MAX=52428800 + + # All arguments except the last are considered to be for NsJail, not Python. local nsj_args="" while [ "$#" -gt 1 ]; do nsj_args="${nsj_args:+${nsj_args} }$1" shift done + # Set up cgroups and disable memory swapping. mkdir -p /sys/fs/cgroup/pids/NSJAIL mkdir -p /sys/fs/cgroup/memory/NSJAIL + echo "${MEM_MAX}" > /sys/fs/cgroup/memory/NSJAIL/memory.limit_in_bytes + echo "${MEM_MAX}" > /sys/fs/cgroup/memory/NSJAIL/memory.memsw.limit_in_bytes + nsjail \ -Mo \ --rlimit_as 700 \ @@ -19,7 +26,7 @@ nsjpy() { --disable_proc \ --iface_no_lo \ --cgroup_pids_max=1 \ - --cgroup_mem_max=52428800 \ + --cgroup_mem_max="${MEM_MAX}" \ $nsj_args -- \ /snekbox/.venv/bin/python3 -Iq -c "$@" } diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index b68b0b9..b9c4fc7 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -24,6 +24,7 @@ CGROUP_PIDS_PARENT = Path("/sys/fs/cgroup/pids/NSJAIL") CGROUP_MEMORY_PARENT = Path("/sys/fs/cgroup/memory/NSJAIL") NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail") +MEM_MAX = 52428800 class NsJail: @@ -59,10 +60,21 @@ class NsJail: NsJail doesn't do this automatically because it requires privileges NsJail usually doesn't have. + + Disables memory swapping. """ pids.mkdir(parents=True, exist_ok=True) mem.mkdir(parents=True, exist_ok=True) + # Swap limit cannot be set to a value lower than memory.limit_in_bytes. + # Therefore, this must be set first. + with (mem / "memory.limit_in_bytes").open("w", encoding="utf=8") as f: + f.write(str(MEM_MAX)) + + # Swap limit is specified as the sum of the memory and swap limits. + with (mem / "memory.memsw.limit_in_bytes").open("w", encoding="utf=8") as f: + f.write(str(MEM_MAX)) + @staticmethod def _parse_log(log_lines: Iterable[str]): """Parse and log NsJail's log messages.""" @@ -108,7 +120,7 @@ class NsJail: "--disable_proc", "--iface_no_lo", "--log", nsj_log.name, - "--cgroup_mem_max=52428800", + f"--cgroup_mem_max={MEM_MAX}", "--cgroup_mem_mount", str(CGROUP_MEMORY_PARENT.parent), "--cgroup_mem_parent", CGROUP_MEMORY_PARENT.name, "--cgroup_pids_max=1", diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index e3b8eb3..f1a60e6 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -2,7 +2,7 @@ import logging import unittest from textwrap import dedent -from snekbox.nsjail import NsJail +from snekbox.nsjail import MEM_MAX, NsJail class NsJailTests(unittest.TestCase): @@ -21,12 +21,8 @@ class NsJailTests(unittest.TestCase): def test_timeout_returns_137(self): code = dedent(""" - x = '*' while True: - try: - x = x * 99 - except: - continue + pass """).strip() with self.assertLogs(self.logger) as log: @@ -37,6 +33,17 @@ class NsJailTests(unittest.TestCase): self.assertEqual(result.stderr, None) self.assertIn("run time >= time limit", "\n".join(log.output)) + def test_memory_returns_137(self): + # Add a kilobyte just to be safe. + code = dedent(f""" + x = ' ' * {MEM_MAX + 1000} + """).strip() + + result = self.nsjail.python3(code) + self.assertEqual(result.returncode, 137) + self.assertEqual(result.stdout, "") + self.assertEqual(result.stderr, None) + def test_subprocess_resource_unavailable(self): code = dedent(""" import subprocess |