aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Mark <[email protected]>2022-06-04 18:26:21 -0700
committerGravatar Mark <[email protected]>2022-06-04 23:22:25 -0700
commite093f20bad446bb6023ef243fe867d5f2f7b4334 (patch)
treeb7d80c68ce3931e8cd7464e0f3251bf64b242b5c
parentFix force-exclude option for black (diff)
Add config path & output size args to NsJail class
-rw-r--r--snekbox/nsjail.py43
-rw-r--r--tests/test_nsjail.py6
2 files changed, 27 insertions, 22 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index 1f99501..5616733 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -25,10 +25,6 @@ LOG_PATTERN = re.compile(
NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail")
NSJAIL_CFG = os.getenv("NSJAIL_CFG", "./config/snekbox.cfg")
-# Limit of stdout bytes we consume before terminating nsjail
-OUTPUT_MAX = 1_000_000 # 1 MB
-READ_CHUNK_SIZE = 10_000 # chars
-
class NsJail:
"""
@@ -37,33 +33,43 @@ class NsJail:
See config/snekbox.cfg for the default NsJail configuration.
"""
- def __init__(self, nsjail_binary: str = NSJAIL_PATH):
- self.nsjail_binary = nsjail_binary
- self.config = self._read_config()
+ def __init__(
+ self,
+ nsjail_path: str = NSJAIL_PATH,
+ config_path: str = NSJAIL_CFG,
+ max_output_size: int = 1_000_000,
+ read_chunk_size: int = 10_000,
+ ):
+ self.nsjail_path = nsjail_path
+ self.config_path = config_path
+ self.max_output_size = max_output_size
+ self.read_chunk_size = read_chunk_size
+
+ self.config = self._read_config(config_path)
self.cgroup_version = utils.cgroup.init(self.config)
self.ignore_swap_limits = utils.swap.should_ignore_limit(self.config, self.cgroup_version)
log.info(f"Assuming cgroup version {self.cgroup_version}.")
@staticmethod
- def _read_config() -> NsJailConfig:
- """Read the NsJail config at `NSJAIL_CFG` and return a protobuf Message object."""
+ def _read_config(config_path: str) -> NsJailConfig:
+ """Read the NsJail config at `config_path` and return a protobuf Message object."""
config = NsJailConfig()
try:
- with open(NSJAIL_CFG, encoding="utf-8") as f:
+ with open(config_path, encoding="utf-8") as f:
config_text = f.read()
except FileNotFoundError:
- log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be found.")
+ log.fatal(f"The NsJail config at {config_path!r} could not be found.")
sys.exit(1)
except OSError as e:
- log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be read.", exc_info=e)
+ log.fatal(f"The NsJail config at {config_path!r} could not be read.", exc_info=e)
sys.exit(1)
try:
text_format.Parse(config_text, config)
except text_format.ParseError as e:
- log.fatal(f"The NsJail config at {NSJAIL_CFG!r} could not be parsed.", exc_info=e)
+ log.fatal(f"The NsJail config at {config_path!r} could not be parsed.", exc_info=e)
sys.exit(1)
return config
@@ -94,8 +100,7 @@ class NsJail:
# Treat fatal as error.
log.error(msg)
- @staticmethod
- def _consume_stdout(nsjail: subprocess.Popen) -> str:
+ def _consume_stdout(self, nsjail: subprocess.Popen) -> str:
"""
Consume STDOUT, stopping when the output limit is reached or NsJail has exited.
@@ -114,11 +119,11 @@ class NsJail:
with nsjail:
# We'll consume STDOUT as long as the NsJail subprocess is running.
while nsjail.poll() is None:
- chars = nsjail.stdout.read(READ_CHUNK_SIZE)
+ chars = nsjail.stdout.read(self.read_chunk_size)
output_size += sys.getsizeof(chars)
output.append(chars)
- if output_size > OUTPUT_MAX:
+ if output_size > self.max_output_size:
# Terminate the NsJail subprocess with SIGTERM.
# This in turn reaps and kills children with SIGKILL.
log.info("Output exceeded the output limit, sending SIGTERM to NsJail.")
@@ -153,9 +158,9 @@ class NsJail:
with NamedTemporaryFile() as nsj_log:
args = (
- self.nsjail_binary,
+ self.nsjail_path,
"--config",
- NSJAIL_CFG,
+ self.config_path,
"--log",
nsj_log.name,
*nsjail_args,
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index a3632e6..dc30dfa 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -5,7 +5,7 @@ import unittest
import unittest.mock
from textwrap import dedent
-from snekbox.nsjail import OUTPUT_MAX, READ_CHUNK_SIZE, NsJail
+from snekbox.nsjail import NsJail
class NsJailTests(unittest.TestCase):
@@ -255,8 +255,8 @@ class NsJailTests(unittest.TestCase):
self.assertEqual(result.returncode, 143)
def test_large_output_is_truncated(self):
- chunk = "a" * READ_CHUNK_SIZE
- expected_chunks = OUTPUT_MAX // sys.getsizeof(chunk) + 1
+ chunk = "a" * self.nsjail.read_chunk_size
+ expected_chunks = self.nsjail.max_output_size // sys.getsizeof(chunk) + 1
nsjail_subprocess = unittest.mock.MagicMock()