aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--snekbox/nsjail.py61
1 files changed, 39 insertions, 22 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index b792ba9..80d9fd5 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -104,6 +104,42 @@ class NsJail:
# Treat fatal as error.
log.error(msg)
+ @staticmethod
+ def _consume_stdout(nsjail: subprocess.Popen) -> str:
+ """
+ Consume STDOUT, stopping when the output limit is reached or NsJail has exited.
+
+ The aim of this function is to limit the size of the output received from
+ NsJail to prevent container from claiming too much memory. If the output
+ received from STDOUT goes over the OUTPUT_MAX limit, the NsJail subprocess
+ is asked to terminate with a SIGTERM.
+
+ Once the subprocess has exited, either naturally or because it was terminated,
+ the output up to that point is returned as a string.
+ """
+ output_size = 0
+ output = []
+
+ # We'll consume STDOUT as long as the NsJail subprocess is running
+ while nsjail.poll() is None:
+ # Read 100 characters from the STDOUT stream
+ chars = nsjail.stdout.read(100)
+ chars_size = sys.getsizeof(chars)
+
+ # Check if these characters take us over the output limit
+ if output_size + chars_size > OUTPUT_MAX:
+ # Ask nsjail to terminate itself using SIGTERM
+ nsjail.terminate()
+ break
+
+ output_size += chars_size
+ output.append(chars)
+
+ # Ensure that we wait for the nsjail process to terminate
+ nsjail.wait()
+
+ return "".join(output)
+
def python3(self, code: str) -> CompletedProcess:
"""Execute Python 3 code in an isolated environment and return the completed process."""
with NamedTemporaryFile() as nsj_log:
@@ -136,33 +172,14 @@ class NsJail:
except ValueError:
return CompletedProcess(args, None, "ValueError: embedded null byte", None)
- output_size = 0
- output = []
-
- # We'll consume STDOUT as long as the subprocess is running
- while nsjail.poll() is None:
- # Read 100 characters from the STDOUT stream
- chars = nsjail.stdout.read(100)
- chars_size = sys.getsizeof(chars)
-
- # Check if these characters take us over the output limit
- if output_size + chars_size > OUTPUT_MAX:
- # Ask nsjail to terminate itself using SIGTERM
- nsjail.terminate()
- break
-
- output_size += chars_size
- output.append(chars)
-
- # Ensure that we wait for the nsjail process to terminate
- nsjail.wait()
+ output = self._consume_stdout(nsjail)
log_lines = nsj_log.read().decode("utf-8").splitlines()
if not log_lines and nsjail.returncode == 255:
# NsJail probably failed to parse arguments so log output will still be in stdout
- log_lines = "".join(output).splitlines()
+ log_lines = output.splitlines()
self._parse_log(log_lines)
log.info(f"nsjail return code: {nsjail.returncode}")
- return CompletedProcess(args, nsjail.returncode, "".join(output), None)
+ return CompletedProcess(args, nsjail.returncode, output, None)