diff options
author | 2022-11-22 13:07:15 -0500 | |
---|---|---|
committer | 2022-11-22 13:07:15 -0500 | |
commit | 621c835f28cd3fcd1e5c3cccae9b5057647f9fa1 (patch) | |
tree | 50d9dd6bd6bf84c8c201995adfa59ed0f1832c33 | |
parent | Readded lib64 (diff) |
Fixed leading empty filter for py_args
-rw-r--r-- | snekbox/nsjail.py | 26 | ||||
-rw-r--r-- | tests/test_nsjail.py | 17 |
2 files changed, 34 insertions, 9 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index fb4ff1d..4050fae 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -2,8 +2,9 @@ import logging import re import subprocess import sys +from collections.abc import Generator from tempfile import NamedTemporaryFile -from typing import Iterable +from typing import Iterable, TypeVar from google.protobuf import text_format @@ -19,18 +20,35 @@ from snekbox.utils.timed import timed log = logging.getLogger(__name__) +T = TypeVar("T") + # [level][timestamp][PID]? function_signature:line_no? message LOG_PATTERN = re.compile( r"\[(?P<level>(I)|[DWEF])\]\[.+?\](?(2)|(?P<func>\[\d+\] .+?:\d+ )) ?(?P<msg>.+)" ) +def iter_lstrip(iterable: Iterable[T]) -> Generator[T, None, None]: + """Removes leading falsy objects from an iterable.""" + it = iter(iterable) + for item in it: + if item: + yield item + break + yield from it + + def parse_files( fs: MemFS, files_limit: int, files_pattern: str, ) -> list[FileAttachment]: - """Parse files in a MemFS.""" + """ + Parse files in a MemFS. + + Returns: + List of FileAttachments sorted lexically by path name. + """ return sorted(fs.attachments(files_limit, files_pattern), key=lambda file: file.path) @@ -201,9 +219,9 @@ class NsJail: "--", self.config.exec_bin.path, *self.config.exec_bin.arg, - # Filter out empty strings at start of iterable + # Filter out empty strings at start of py_args # (causes issues with python cli) - *(arg for i, arg in enumerate(py_args) if (arg or i > 0)), + *iter_lstrip(py_args), ] # Write files if any diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index dca0a8f..da2afea 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -361,11 +361,18 @@ class NsJailTests(unittest.TestCase): self.assertEqual(result.args[end - len(args) : end], args) def test_py_args(self): - args = ["-m", "timeit"] - result = self.nsjail.python3(args) - - self.assertEqual(result.returncode, 0) - self.assertEqual(result.args[-2:], args) + expected = ["-m", "timeit"] + args = [ + ["", "-m", "timeit"], + ["", "", "-m", "timeit"], + ["", "", "", "-m", "timeit"], + ] + # Leading empty strings should be removed + for case in args: + with self.subTest(args=args): + result = self.nsjail.python3(case) + self.assertEqual(result.returncode, 0) + self.assertEqual(result.args[-2:], expected) class NsJailArgsTests(unittest.TestCase): |