diff options
| -rw-r--r-- | snekbox/nsjail.py | 21 | ||||
| -rw-r--r-- | snekbox/utils/cgroup.py | 70 | 
2 files changed, 27 insertions, 64 deletions
| diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 2d84d3b..79d33ba 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -4,7 +4,6 @@ import re  import subprocess  import sys  import textwrap -from pathlib import Path  from subprocess import CompletedProcess  from tempfile import NamedTemporaryFile  from typing import Iterable @@ -40,10 +39,7 @@ class NsJail:      def __init__(self, nsjail_binary: str = NSJAIL_PATH):          self.nsjail_binary = nsjail_binary          self.config = self._read_config() -        self.cgroup_version = utils.cgroup.get_version(self.config) - -        if self.cgroup_version == 2: -            utils.cgroup.init_v2(self.config) +        self.cgroup_version = utils.cgroup.init(self.config)          log.info(f"Assuming cgroup version {self.cgroup_version}.") @@ -149,15 +145,7 @@ class NsJail:          `py_args` are arguments to pass to the Python subprocess before the code,          which is the last argument. By default, it's "-c", which executes the code given.          """ -        cgroup = None -        if self.cgroup_version == 1: -            cgroup = utils.cgroup.create_dynamic(self.config) -            nsjail_args = ( -                "--cgroup_mem_parent", cgroup, -                "--cgroup_pids_parent", cgroup, -                *nsjail_args, -            ) -        else: +        if self.cgroup_version == 2:              nsjail_args = ("--use_cgroupv2", *nsjail_args)          with NamedTemporaryFile() as nsj_log: @@ -209,9 +197,4 @@ class NsJail:          log.info(f"nsjail return code: {returncode}") -        # Remove the dynamically created cgroups once we're done -        if self.cgroup_version == 1: -            Path(self.config.cgroup_mem_mount, cgroup).rmdir() -            Path(self.config.cgroup_pids_mount, cgroup).rmdir() -          return CompletedProcess(args, returncode, output, None) diff --git a/snekbox/utils/cgroup.py b/snekbox/utils/cgroup.py index 3df269e..3e12406 100644 --- a/snekbox/utils/cgroup.py +++ b/snekbox/utils/cgroup.py @@ -1,5 +1,4 @@  import logging -import uuid  from pathlib import Path  from snekbox.config_pb2 import NsJailConfig @@ -7,50 +6,6 @@ from snekbox.config_pb2 import NsJailConfig  log = logging.getLogger(__name__) -def create_dynamic(config: NsJailConfig) -> str: -    """ -    Create a PID and memory cgroup for NsJail to use as the parent cgroup. - -    Returns the name of the cgroup, located in the cgroup root. - -    NsJail doesn't do this automatically because it requires privileges NsJail usually doesn't -    have. - -    Disables memory swapping. -    """ -    # Pick a name for the cgroup -    cgroup = "snekbox-" + str(uuid.uuid4()) - -    pids = Path(config.cgroup_pids_mount, cgroup) -    mem = Path(config.cgroup_mem_mount, cgroup) -    mem_max = str(config.cgroup_mem_max) - -    pids.mkdir(parents=True, exist_ok=True) -    mem.mkdir(parents=True, exist_ok=True) - -    # Swap limit cannot be set to a value lower than memory.limit_in_bytes. -    # Therefore, this must be set before the swap limit. -    # -    # Since child cgroups are dynamically created, the swap limit has to be set on the parent -    # instead so that children inherit it. Given the swap's dependency on the memory limit, -    # the memory limit must also be set on the parent. NsJail only sets the memory limit for -    # child cgroups, not the parent. -    (mem / "memory.limit_in_bytes").write_text(mem_max) - -    try: -        # Swap limit is specified as the sum of the memory and swap limits. -        # Therefore, setting it equal to the memory limit effectively disables swapping. -        (mem / "memory.memsw.limit_in_bytes").write_text(mem_max) -    except PermissionError: -        log.warning( -            "Failed to set the memory swap limit for the cgroup. " -            "This is probably because CONFIG_MEMCG_SWAP or CONFIG_MEMCG_SWAP_ENABLED is unset. " -            "Please ensure swap memory is disabled on the system." -        ) - -    return cgroup - -  def get_version(config: NsJailConfig) -> int:      """      Examine the filesystem and return the guessed cgroup version. @@ -95,6 +50,31 @@ def get_version(config: NsJailConfig) -> int:          return config_version +def init(config: NsJailConfig) -> int: +    """Determine the cgroup version, initialise the cgroups for NsJail, and return the version.""" +    version = get_version(config) +    if version == 1: +        init_v1(config) +    else: +        init_v2(config) + +    return version + + +def init_v1(config: NsJailConfig) -> None: +    """ +    Create a PID and memory cgroup for NsJail to use as the parent cgroup for each controller. + +    NsJail doesn't do this automatically because it requires privileges NsJail usually doesn't +    have. +    """ +    pids = Path(config.cgroup_pids_mount, config.cgroup_pids_parent) +    mem = Path(config.cgroup_mem_mount, config.cgroup_mem_parent) + +    pids.mkdir(parents=True, exist_ok=True) +    mem.mkdir(parents=True, exist_ok=True) + +  def init_v2(config: NsJailConfig) -> None:      """Ensure cgroupv2 children have controllers enabled."""      cgroup_mount = Path(config.cgroupv2_mount) | 
