diff options
author | 2022-06-04 18:26:21 -0700 | |
---|---|---|
committer | 2022-06-04 23:22:25 -0700 | |
commit | e093f20bad446bb6023ef243fe867d5f2f7b4334 (patch) | |
tree | b7d80c68ce3931e8cd7464e0f3251bf64b242b5c | |
parent | Fix force-exclude option for black (diff) |
Add config path & output size args to NsJail class
-rw-r--r-- | snekbox/nsjail.py | 43 | ||||
-rw-r--r-- | tests/test_nsjail.py | 6 |
2 files changed, 27 insertions, 22 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 1f99501..5616733 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -25,10 +25,6 @@ LOG_PATTERN = re.compile( NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail") NSJAIL_CFG = os.getenv("NSJAIL_CFG", "./config/snekbox.cfg") -# Limit of stdout bytes we consume before terminating nsjail -OUTPUT_MAX = 1_000_000 # 1 MB -READ_CHUNK_SIZE = 10_000 # chars - class NsJail: """ @@ -37,33 +33,43 @@ class NsJail: See config/snekbox.cfg for the default NsJail configuration. """ - def __init__(self, nsjail_binary: str = NSJAIL_PATH): - self.nsjail_binary = nsjail_binary - self.config = self._read_config() + def __init__( + self, + nsjail_path: str = NSJAIL_PATH, + config_path: str = NSJAIL_CFG, + max_output_size: int = 1_000_000, + read_chunk_size: int = 10_000, + ): + self.nsjail_path = nsjail_path + self.config_path = config_path + self.max_output_size = max_output_size + self.read_chunk_size = read_chunk_size + + self.config = self._read_config(config_path) self.cgroup_version = utils.cgroup.init(self.config) self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version) log.info(f"Assuming cgroup version {self.cgroup_version}.") @staticmethod - def _read_config() -> NsJailConfig: - """Read the NsJail config at `NSJAIL_CFG` and return a protobuf Message object.""" + def _read_config(config_path: str) -> NsJailConfig: + """Read the NsJail config at `config_path` and return a protobuf Message object.""" config = NsJailConfig() try: - with open(NSJAIL_CFG, encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: config_text = f.read() except FileNotFoundError: - log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be found.") + log.fatal(f"The NsJail config at {config_path!r} could not be found.") sys.exit(1) except OSError as e: - log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be read.", exc_info=e) + log.fatal(f"The NsJail config at {config_path!r} could not be read.", exc_info=e) sys.exit(1) try: text_format.Parse(config_text, config) except text_format.ParseError as e: - log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be parsed.", exc_info=e) + log.fatal(f"The NsJail config at {config_path!r} could not be parsed.", exc_info=e) sys.exit(1) return config @@ -94,8 +100,7 @@ class NsJail: # Treat fatal as error. log.error(msg) - @staticmethod - def _consume_stdout(nsjail: subprocess.Popen) -> str: + def _consume_stdout(self, nsjail: subprocess.Popen) -> str: """ Consume STDOUT, stopping when the output limit is reached or NsJail has exited. @@ -114,11 +119,11 @@ class NsJail: with nsjail: # We'll consume STDOUT as long as the NsJail subprocess is running. while nsjail.poll() is None: - chars = nsjail.stdout.read(READ_CHUNK_SIZE) + chars = nsjail.stdout.read(self.read_chunk_size) output_size += sys.getsizeof(chars) output.append(chars) - if output_size > OUTPUT_MAX: + if output_size > self.max_output_size: # Terminate the NsJail subprocess with SIGTERM. # This in turn reaps and kills children with SIGKILL. log.info("Output exceeded the output limit, sending SIGTERM to NsJail.") @@ -153,9 +158,9 @@ class NsJail: with NamedTemporaryFile() as nsj_log: args = ( - self.nsjail_binary, + self.nsjail_path, "--config", - NSJAIL_CFG, + self.config_path, "--log", nsj_log.name, *nsjail_args, diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index a3632e6..dc30dfa 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -5,7 +5,7 @@ import unittest import unittest.mock from textwrap import dedent -from snekbox.nsjail import OUTPUT_MAX, READ_CHUNK_SIZE, NsJail +from snekbox.nsjail import NsJail class NsJailTests(unittest.TestCase): @@ -255,8 +255,8 @@ class NsJailTests(unittest.TestCase): self.assertEqual(result.returncode, 143) def test_large_output_is_truncated(self): - chunk = "a" * READ_CHUNK_SIZE - expected_chunks = OUTPUT_MAX // sys.getsizeof(chunk) + 1 + chunk = "a" * self.nsjail.read_chunk_size + expected_chunks = self.nsjail.max_output_size // sys.getsizeof(chunk) + 1 nsjail_subprocess = unittest.mock.MagicMock() |