diff options
| -rw-r--r-- | scripts/.profile | 9 | ||||
| -rw-r--r-- | snekbox/nsjail.py | 14 | ||||
| -rw-r--r-- | tests/test_nsjail.py | 19 | 
3 files changed, 34 insertions, 8 deletions
| diff --git a/scripts/.profile b/scripts/.profile index 415e4f6..bff260d 100644 --- a/scripts/.profile +++ b/scripts/.profile @@ -1,12 +1,19 @@  nsjpy() { +    local MEM_MAX=52428800 + +    # All arguments except the last are considered to be for NsJail, not Python.      local nsj_args=""      while [ "$#" -gt 1 ]; do          nsj_args="${nsj_args:+${nsj_args} }$1"          shift      done +    # Set up cgroups and disable memory swapping.      mkdir -p /sys/fs/cgroup/pids/NSJAIL      mkdir -p /sys/fs/cgroup/memory/NSJAIL +    echo "${MEM_MAX}" > /sys/fs/cgroup/memory/NSJAIL/memory.limit_in_bytes +    echo "${MEM_MAX}" > /sys/fs/cgroup/memory/NSJAIL/memory.memsw.limit_in_bytes +      nsjail \          -Mo \          --rlimit_as 700 \ @@ -19,7 +26,7 @@ nsjpy() {          --disable_proc \          --iface_no_lo \          --cgroup_pids_max=1 \ -        --cgroup_mem_max=52428800 \ +        --cgroup_mem_max="${MEM_MAX}" \          $nsj_args -- \          /snekbox/.venv/bin/python3 -Iq -c "$@"  } diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index b68b0b9..b9c4fc7 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -24,6 +24,7 @@ CGROUP_PIDS_PARENT = Path("/sys/fs/cgroup/pids/NSJAIL")  CGROUP_MEMORY_PARENT = Path("/sys/fs/cgroup/memory/NSJAIL")  NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail") +MEM_MAX = 52428800  class NsJail: @@ -59,10 +60,21 @@ class NsJail:          NsJail doesn't do this automatically because it requires privileges NsJail usually doesn't          have. + +        Disables memory swapping.          """          pids.mkdir(parents=True, exist_ok=True)          mem.mkdir(parents=True, exist_ok=True) +        # Swap limit cannot be set to a value lower than memory.limit_in_bytes. +        # Therefore, this must be set first. +        with (mem / "memory.limit_in_bytes").open("w", encoding="utf=8") as f: +            f.write(str(MEM_MAX)) + +        # Swap limit is specified as the sum of the memory and swap limits. +        with (mem / "memory.memsw.limit_in_bytes").open("w", encoding="utf=8") as f: +            f.write(str(MEM_MAX)) +      @staticmethod      def _parse_log(log_lines: Iterable[str]):          """Parse and log NsJail's log messages.""" @@ -108,7 +120,7 @@ class NsJail:                  "--disable_proc",                  "--iface_no_lo",                  "--log", nsj_log.name, -                "--cgroup_mem_max=52428800", +                f"--cgroup_mem_max={MEM_MAX}",                  "--cgroup_mem_mount", str(CGROUP_MEMORY_PARENT.parent),                  "--cgroup_mem_parent", CGROUP_MEMORY_PARENT.name,                  "--cgroup_pids_max=1", diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index e3b8eb3..f1a60e6 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -2,7 +2,7 @@ import logging  import unittest  from textwrap import dedent -from snekbox.nsjail import NsJail +from snekbox.nsjail import MEM_MAX, NsJail  class NsJailTests(unittest.TestCase): @@ -21,12 +21,8 @@ class NsJailTests(unittest.TestCase):      def test_timeout_returns_137(self):          code = dedent(""" -            x = '*'              while True: -                try: -                    x = x * 99 -                except: -                    continue +                pass          """).strip()          with self.assertLogs(self.logger) as log: @@ -37,6 +33,17 @@ class NsJailTests(unittest.TestCase):          self.assertEqual(result.stderr, None)          self.assertIn("run time >= time limit", "\n".join(log.output)) +    def test_memory_returns_137(self): +        # Add a kilobyte just to be safe. +        code = dedent(f""" +            x = ' ' * {MEM_MAX + 1000} +        """).strip() + +        result = self.nsjail.python3(code) +        self.assertEqual(result.returncode, 137) +        self.assertEqual(result.stdout, "") +        self.assertEqual(result.stderr, None) +      def test_subprocess_resource_unavailable(self):          code = dedent("""              import subprocess | 
