aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--snekbox/nsjail.py10
-rw-r--r--snekbox/utils/timed.py9
-rw-r--r--tests/test_nsjail.py34
3 files changed, 40 insertions, 13 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index 006ff98..fb4ff1d 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -241,10 +241,12 @@ class NsJail:
(fs, self.files_limit, self.files_pattern),
timeout=self.files_timeout,
)
- log.info(f"Found {len(attachments)} attachments.")
- except TimeoutError as err:
- log.info(f"Exceeded time limit in parsing attachments: {err}")
- return EvalResult(args, returncode, f"AttachmentError: {err}")
+ log.info(f"Found {len(attachments)} files.")
+ except TimeoutError as e:
+ log.warning(f"Exceeded time limit while parsing attachments: {e}")
+ return EvalResult(
+ args, None, "TimeoutError: Exceeded time limit while parsing attachments"
+ )
log_lines = nsj_log.read().decode("utf-8").splitlines()
if not log_lines and returncode == 255:
diff --git a/snekbox/utils/timed.py b/snekbox/utils/timed.py
index e1602b5..4acfa0d 100644
--- a/snekbox/utils/timed.py
+++ b/snekbox/utils/timed.py
@@ -1,6 +1,6 @@
"""Calling functions with time limits."""
+import multiprocessing
from collections.abc import Callable, Iterable, Mapping
-from multiprocessing import Pool
from typing import Any, TypeVar
_T = TypeVar("_T")
@@ -23,6 +23,9 @@ def timed(
"""
if kwds is None:
kwds = {}
- with Pool(1) as pool:
+ with multiprocessing.Pool(1) as pool:
result = pool.apply_async(func, args, kwds)
- return result.get(timeout)
+ try:
+ return result.get(timeout)
+ except multiprocessing.TimeoutError as e:
+ raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index 493e1f1..dca0a8f 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -17,19 +17,17 @@ class NsJailTests(unittest.TestCase):
def setUp(self):
super().setUp()
- self.nsjail = NsJail()
- # Set a lower memfs size limit so tests don't exceed time limit
- self.nsjail.memfs_instance_size = 2 * 1024 * 1024 # 2MiB
-
+ # Specify lower limits for unit tests to complete within time limits
+ self.nsjail = NsJail(memfs_instance_size=2 * 1024 * 1024)
self.logger = logging.getLogger("snekbox.nsjail")
self.logger.setLevel(logging.WARNING)
def eval_code(self, code: str):
return self.nsjail.python3(["-c", code])
- def eval_file(self, code: str, name: str = "test.py"):
+ def eval_file(self, code: str, name: str = "test.py", **kwargs):
file = FileAttachment(name, code)
- return self.nsjail.python3([name], [file])
+ return self.nsjail.python3([name], [file], **kwargs)
def test_print_returns_0(self):
for fn in (self.eval_code, self.eval_file):
@@ -185,6 +183,30 @@ class NsJailTests(unittest.TestCase):
self.assertIn("Resource temporarily unavailable", result.stdout)
self.assertEqual(result.stderr, None)
+ def test_file_parsing_timeout(self):
+ code = dedent(
+ """
+ import os
+ data = "a" * 1024
+ size = 32 * 1024 * 1024
+
+ with open("src", "w") as f:
+ for _ in range((size // 1024) - 5):
+ f.write(data)
+
+ for i in range(100):
+ os.symlink("src", f"output{i}")
+ """
+ ).strip()
+
+ nsjail = NsJail(memfs_instance_size=48 * 1024 * 1024, files_timeout=1)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, None)
+ self.assertEqual(
+ result.stdout, "TimeoutError: Exceeded time limit while parsing attachments"
+ )
+ self.assertEqual(result.stderr, None)
+
def test_sigsegv_returns_139(self): # In honour of Juan.
code = dedent(
"""