diff options
-rw-r--r-- | snekbox/utils/cgroup.py | 31 | ||||
-rw-r--r-- | tests/test_nsjail.py | 68 |
2 files changed, 93 insertions, 6 deletions
diff --git a/snekbox/utils/cgroup.py b/snekbox/utils/cgroup.py index b06cdfa..cc16178 100644 --- a/snekbox/utils/cgroup.py +++ b/snekbox/utils/cgroup.py @@ -65,16 +65,35 @@ def init(config: NsJailConfig) -> int: def init_v1(config: NsJailConfig) -> None: """ - Create a PID and memory cgroup for NsJail to use as the parent cgroup for each controller. + Create cgroups for NsJail to use as the parent cgroup for each in-use controller. + + A controller is in-use if any of its settings (except the mount and parent) have a non-default + value in the NsJail config. NsJail doesn't do this automatically because it requires privileges NsJail usually doesn't have. """ - pids = Path(config.cgroup_pids_mount, config.cgroup_pids_parent) - mem = Path(config.cgroup_mem_mount, config.cgroup_mem_parent) - - pids.mkdir(parents=True, exist_ok=True) - mem.mkdir(parents=True, exist_ok=True) + # If the config doesn't "have" a value, then it's set to the default value, which means the + # controller is not being used. + if config.HasField("cgroup_cpu_ms_per_sec"): + pids = Path(config.cgroup_cpu_mount, config.cgroup_cpu_parent) + pids.mkdir(parents=True, exist_ok=True) + + if ( + config.HasField("cgroup_mem_max") + or config.HasField("cgroup_mem_memsw_max") + or config.HasField("cgroup_mem_swap_max") + ): + mem = Path(config.cgroup_mem_mount, config.cgroup_mem_parent) + mem.mkdir(parents=True, exist_ok=True) + + if config.HasField("cgroup_net_cls_classid"): + net_cls = Path(config.cgroup_net_cls_mount, config.cgroup_net_cls_parent) + net_cls.mkdir(parents=True, exist_ok=True) + + if config.HasField("cgroup_pids_max"): + pids = Path(config.cgroup_pids_mount, config.cgroup_pids_parent) + pids.mkdir(parents=True, exist_ok=True) def init_v2(config: NsJailConfig) -> None: diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index 492c2f9..6f6e2a7 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -5,6 +5,8 @@ import sys import tempfile import unittest import unittest.mock +from itertools import product +from pathlib import Path from textwrap import dedent from snekbox.nsjail import NsJail @@ -316,3 +318,69 @@ class NsJailArgsTests(unittest.TestCase): self.assertEqual(self.nsjail.config_path, self.config_path) self.assertEqual(self.nsjail.max_output_size, self.max_output_size) self.assertEqual(self.nsjail.read_chunk_size, self.read_chunk_size) + + +class NsJailCgroupTests(unittest.TestCase): + # This should still pass for v2, even if this test isn't relevant. + def test_cgroupv1(self): + logging.getLogger("snekbox.nsjail").setLevel(logging.ERROR) + logging.getLogger("snekbox.utils.swap").setLevel(logging.ERROR) + + config_base = dedent( + """ + mode: ONCE + mount { + src: "/" + dst: "/" + is_bind: true + rw: false + } + exec_bin { + path: "/bin/su" + arg: "" + } + """ + ).strip() + + cases = ( + ( + ( + "cgroup_mem_max: 52428800", + # memory.limit_in_bytes must be set before memory.memsw.limit_in_bytes + "cgroup_mem_max: 52428800\ncgroup_mem_memsw_max: 104857600", + "cgroup_mem_max: 52428800\ncgroup_mem_swap_max: 52428800", + ), + "cgroup_mem_mount: '/sys/fs/cgroup/memory'", + "cgroup_mem_parent: 'NSJAILTEST1'", + ), + ( + ("cgroup_pids_max: 20",), + "cgroup_pids_mount: '/sys/fs/cgroup/pids'", + "cgroup_pids_parent: 'NSJAILTEST2'", + ), + ( + ("cgroup_net_cls_classid: 1048577",), + "cgroup_net_cls_mount: '/sys/fs/cgroup/net_cls'", + "cgroup_net_cls_parent: 'NSJAILTEST3'", + ), + ( + ("cgroup_cpu_ms_per_sec: 800",), + "cgroup_cpu_mount: '/sys/fs/cgroup/cpu'", + "cgroup_cpu_parent: 'NSJAILTEST4'", + ), + ) + + # protobuf doesn't parse correctly when NamedTemporaryFile is used directly. + with tempfile.TemporaryDirectory() as directory: + for values, mount, parent in cases: + for lines in product(values, (mount, ""), (parent, "")): + with self.subTest(config=lines): + config_path = str(Path(directory, "config.cfg")) + with open(config_path, "w", encoding="utf8") as f: + f.write("\n".join(lines + (config_base,))) + + nsjail = NsJail(config_path=config_path) + + result = nsjail.python3("") + + self.assertNotEqual(result.returncode, 255) |