diff options
author | 2021-12-21 13:35:57 -0800 | |
---|---|---|
committer | 2021-12-21 13:35:57 -0800 | |
commit | 3c52aeffcff841c74c18d7ee936144d3878b7d89 (patch) | |
tree | cf627e8bef0dfac680dbedeccd9401e5d77f36b3 | |
parent | Show a warning if use_cgroupv2 is true but only a v1 fs is detected (diff) |
Add a test for #83
-rw-r--r-- | tests/gunicorn_utils.py | 80 | ||||
-rw-r--r-- | tests/test_integration.py | 43 |
2 files changed, 123 insertions, 0 deletions
diff --git a/tests/gunicorn_utils.py b/tests/gunicorn_utils.py new file mode 100644 index 0000000..f2d9b6d --- /dev/null +++ b/tests/gunicorn_utils.py @@ -0,0 +1,80 @@ +import concurrent.futures +import contextlib +import multiprocessing +from typing import Iterator + +from gunicorn.app.base import Application + + +class _StandaloneApplication(Application): + def __init__(self, config_path: str = None, **kwargs): + self.config_path = config_path + self.options = kwargs + + super().__init__() + + def init(self, parser, opts, args): + pass + + def load(self): + from snekbox.api.app import application + return application + + def load_config(self): + for key, value in self.options.items(): + if key in self.cfg.settings and value is not None: + self.cfg.set(key.lower(), value) + + if self.config_path: + self.load_config_from_file(self.config_path) + + +def _proc_target(config_path: str, event: multiprocessing.Event, **kwargs) -> None: + """Run a Gunicorn app with the given config and set `event` when Gunicorn is ready.""" + def when_ready(_): + event.set() + + app = _StandaloneApplication(config_path, when_ready=when_ready, **kwargs) + + import logging + logging.disable(logging.INFO) + + app.run() + + +def run_gunicorn(config_path: str = "config/gunicorn.conf.py", **kwargs) -> Iterator[None]: + """ + Run the Snekbox app through separate Gunicorn process. Use as a context manager. + + `config_path` is the path to the Gunicorn config to use. + Additional kwargs are interpreted as Gunicorn settings. + + Raise RuntimeError if Gunicorn terminates before it is ready. + Raise TimeoutError if Gunicorn isn't ready after 60 seconds. + """ + event = multiprocessing.Event() + proc = multiprocessing.Process(target=_proc_target, args=(config_path, event), kwargs=kwargs) + + try: + proc.start() + + # Wait 60 seconds for Gunicorn to be ready, but exit early if Gunicorn fails. + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + concurrent.futures.wait( + [executor.submit(proc.join), executor.submit(event.wait)], + timeout=60, + return_when=concurrent.futures.FIRST_COMPLETED + ) + # Can't use the context manager cause wait=False needs to be set. + executor.shutdown(wait=False, cancel_futures=True) + + if proc.is_alive(): + if not event.is_set(): + raise TimeoutError("Timed out waiting for Gunicorn to be ready.") + else: + raise RuntimeError(f"Gunicorn terminated unexpectedly with code {proc.exitcode}.") + + yield + finally: + proc.terminate() diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..b76b005 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,43 @@ +import json +import unittest +import urllib.request +from multiprocessing.dummy import Pool + +from tests.gunicorn_utils import run_gunicorn + + +def run_code_in_snekbox(code: str) -> tuple[str, int]: + body = {"input": code} + json_data = json.dumps(body).encode("utf-8") + + req = urllib.request.Request("http://localhost:8060/eval") + req.add_header("Content-Type", "application/json; charset=utf-8") + req.add_header("Content-Length", str(len(json_data))) + + with urllib.request.urlopen(req, json_data, timeout=30) as response: + response_data = response.read().decode("utf-8") + + return response_data, response.status + + +class IntegrationTests(unittest.TestCase): + + def test_memory_limit_separate_per_process(self): + """ + Each NsJail process should have its own memory limit. + + The memory used by one process should not contribute to the memory cap of other processes. + See https://github.com/python-discord/snekbox/issues/83 + """ + with run_gunicorn(): + code = "import time; ' ' * 33000000; time.sleep(0.1)" + processes = 3 + + args = [code] * processes + with Pool(processes) as p: + results = p.map(run_code_in_snekbox, args) + + responses, statuses = zip(*results) + + self.assertTrue(all(status == 200 for status in statuses)) + self.assertTrue(all(json.loads(response)["returncode"] == 0 for response in responses)) |