diff options
| -rw-r--r-- | snekbox/nsjail.py | 43 | ||||
| -rw-r--r-- | tests/test_nsjail.py | 6 | 
2 files changed, 27 insertions, 22 deletions
| diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 1f99501..5616733 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -25,10 +25,6 @@ LOG_PATTERN = re.compile(  NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail")  NSJAIL_CFG = os.getenv("NSJAIL_CFG", "./config/snekbox.cfg") -# Limit of stdout bytes we consume before terminating nsjail -OUTPUT_MAX = 1_000_000  # 1 MB -READ_CHUNK_SIZE = 10_000  # chars -  class NsJail:      """ @@ -37,33 +33,43 @@ class NsJail:      See config/snekbox.cfg for the default NsJail configuration.      """ -    def __init__(self, nsjail_binary: str = NSJAIL_PATH): -        self.nsjail_binary = nsjail_binary -        self.config = self._read_config() +    def __init__( +        self, +        nsjail_path: str = NSJAIL_PATH, +        config_path: str = NSJAIL_CFG, +        max_output_size: int = 1_000_000, +        read_chunk_size: int = 10_000, +    ): +        self.nsjail_path = nsjail_path +        self.config_path = config_path +        self.max_output_size = max_output_size +        self.read_chunk_size = read_chunk_size + +        self.config = self._read_config(config_path)          self.cgroup_version = utils.cgroup.init(self.config)          self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version)          log.info(f"Assuming cgroup version {self.cgroup_version}.")      @staticmethod -    def _read_config() -> NsJailConfig: -        """Read the NsJail config at `NSJAIL_CFG` and return a protobuf Message object.""" +    def _read_config(config_path: str) -> NsJailConfig: +        """Read the NsJail config at `config_path` and return a protobuf Message object."""          config = NsJailConfig()          try: -            with open(NSJAIL_CFG, encoding="utf-8") as f: +            with open(config_path, encoding="utf-8") as f:                  config_text = f.read()          except FileNotFoundError: -            log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be found.") +            log.fatal(f"The NsJail config at {config_path!r} could not be found.")              sys.exit(1)          except OSError as e: -            log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be read.", exc_info=e) +            log.fatal(f"The NsJail config at {config_path!r} could not be read.", exc_info=e)              sys.exit(1)          try:              text_format.Parse(config_text, config)          except text_format.ParseError as e: -            log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be parsed.", exc_info=e) +            log.fatal(f"The NsJail config at {config_path!r} could not be parsed.", exc_info=e)              sys.exit(1)          return config @@ -94,8 +100,7 @@ class NsJail:                  # Treat fatal as error.                  log.error(msg) -    @staticmethod -    def _consume_stdout(nsjail: subprocess.Popen) -> str: +    def _consume_stdout(self, nsjail: subprocess.Popen) -> str:          """          Consume STDOUT, stopping when the output limit is reached or NsJail has exited. @@ -114,11 +119,11 @@ class NsJail:          with nsjail:              # We'll consume STDOUT as long as the NsJail subprocess is running.              while nsjail.poll() is None: -                chars = nsjail.stdout.read(READ_CHUNK_SIZE) +                chars = nsjail.stdout.read(self.read_chunk_size)                  output_size += sys.getsizeof(chars)                  output.append(chars) -                if output_size > OUTPUT_MAX: +                if output_size > self.max_output_size:                      # Terminate the NsJail subprocess with SIGTERM.                      # This in turn reaps and kills children with SIGKILL.                      log.info("Output exceeded the output limit, sending SIGTERM to NsJail.") @@ -153,9 +158,9 @@ class NsJail:          with NamedTemporaryFile() as nsj_log:              args = ( -                self.nsjail_binary, +                self.nsjail_path,                  "--config", -                NSJAIL_CFG, +                self.config_path,                  "--log",                  nsj_log.name,                  *nsjail_args, diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index a3632e6..dc30dfa 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -5,7 +5,7 @@ import unittest  import unittest.mock  from textwrap import dedent -from snekbox.nsjail import OUTPUT_MAX, READ_CHUNK_SIZE, NsJail +from snekbox.nsjail import NsJail  class NsJailTests(unittest.TestCase): @@ -255,8 +255,8 @@ class NsJailTests(unittest.TestCase):          self.assertEqual(result.returncode, 143)      def test_large_output_is_truncated(self): -        chunk = "a" * READ_CHUNK_SIZE -        expected_chunks = OUTPUT_MAX // sys.getsizeof(chunk) + 1 +        chunk = "a" * self.nsjail.read_chunk_size +        expected_chunks = self.nsjail.max_output_size // sys.getsizeof(chunk) + 1          nsjail_subprocess = unittest.mock.MagicMock() | 
