aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config/snekbox.cfg1
-rw-r--r--snekbox/nsjail.py33
-rw-r--r--tests/test_nsjail.py9
3 files changed, 38 insertions, 5 deletions
diff --git a/config/snekbox.cfg b/config/snekbox.cfg
index 5b47459..27caf27 100644
--- a/config/snekbox.cfg
+++ b/config/snekbox.cfg
@@ -18,7 +18,6 @@ envar: "NUMEXPR_NUM_THREADS=1"
keep_caps: false
rlimit_as: 700
-rlimit_fsize: 10
clone_newnet: true
clone_newuser: true
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index cafde6d..b792ba9 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -27,6 +27,9 @@ NSJAIL_PATH = os.getenv("NSJAIL_PATH", "/usr/sbin/nsjail")
NSJAIL_CFG = os.getenv("NSJAIL_CFG", "./config/snekbox.cfg")
MEM_MAX = 52428800
+# Limit of stdout bytes we consume before terminating nsjail
+OUTPUT_MAX = 1_000_000 # 1 MB
+
class NsJail:
"""
@@ -124,7 +127,7 @@ class NsJail:
log.info(msg)
try:
- result = subprocess.run(
+ nsjail = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
@@ -133,11 +136,33 @@ class NsJail:
except ValueError:
return CompletedProcess(args, None, "ValueError: embedded null byte", None)
+ output_size = 0
+ output = []
+
+ # We'll consume STDOUT as long as the subprocess is running
+ while nsjail.poll() is None:
+ # Read 100 characters from the STDOUT stream
+ chars = nsjail.stdout.read(100)
+ chars_size = sys.getsizeof(chars)
+
+ # Check if these characters take us over the output limit
+ if output_size + chars_size > OUTPUT_MAX:
+ # Ask nsjail to terminate itself using SIGTERM
+ nsjail.terminate()
+ break
+
+ output_size += chars_size
+ output.append(chars)
+
+ # Ensure that we wait for the nsjail process to terminate
+ nsjail.wait()
+
log_lines = nsj_log.read().decode("utf-8").splitlines()
- if not log_lines and result.returncode == 255:
+ if not log_lines and nsjail.returncode == 255:
# NsJail probably failed to parse arguments so log output will still be in stdout
- log_lines = result.stdout.splitlines()
+ log_lines = "".join(output).splitlines()
self._parse_log(log_lines)
- return result
+ log.info(f"nsjail return code: {nsjail.returncode}")
+ return CompletedProcess(args, nsjail.returncode, "".join(output), None)
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index 0b755b2..852be4b 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -174,3 +174,12 @@ class NsJailTests(unittest.TestCase):
msg="stdout does not come before stderr"
)
self.assertEqual(result.stderr, None)
+
+ def test_stdout_flood_results_in_graceful_sigterm(self):
+ stdout_flood = dedent("""
+ while True:
+ print('abcdefghij')
+ """).strip()
+
+ result = self.nsjail.python3(stdout_flood)
+ self.assertEqual(result.returncode, 143)