aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ionite34 <[email protected]>2022-11-22 13:07:15 -0500
committerGravatar ionite34 <[email protected]>2022-11-22 13:07:15 -0500
commit621c835f28cd3fcd1e5c3cccae9b5057647f9fa1 (patch)
tree50d9dd6bd6bf84c8c201995adfa59ed0f1832c33
parentReadded lib64 (diff)
Fixed leading empty filter for py_args
-rw-r--r--snekbox/nsjail.py26
-rw-r--r--tests/test_nsjail.py17
2 files changed, 34 insertions, 9 deletions
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index fb4ff1d..4050fae 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -2,8 +2,9 @@ import logging
import re
import subprocess
import sys
+from collections.abc import Generator
from tempfile import NamedTemporaryFile
-from typing import Iterable
+from typing import Iterable, TypeVar
from google.protobuf import text_format
@@ -19,18 +20,35 @@ from snekbox.utils.timed import timed
log = logging.getLogger(__name__)
+T = TypeVar("T")
+
# [level][timestamp][PID]? function_signature:line_no? message
LOG_PATTERN = re.compile(
r"\[(?P<level>(I)|[DWEF])\]\[.+?\](?(2)|(?P<func>\[\d+\] .+?:\d+ )) ?(?P<msg>.+)"
)
+def iter_lstrip(iterable: Iterable[T]) -> Generator[T, None, None]:
+ """Removes leading falsy objects from an iterable."""
+ it = iter(iterable)
+ for item in it:
+ if item:
+ yield item
+ break
+ yield from it
+
+
def parse_files(
fs: MemFS,
files_limit: int,
files_pattern: str,
) -> list[FileAttachment]:
- """Parse files in a MemFS."""
+ """
+ Parse files in a MemFS.
+
+ Returns:
+ List of FileAttachments sorted lexically by path name.
+ """
return sorted(fs.attachments(files_limit, files_pattern), key=lambda file: file.path)
@@ -201,9 +219,9 @@ class NsJail:
"--",
self.config.exec_bin.path,
*self.config.exec_bin.arg,
- # Filter out empty strings at start of iterable
+ # Filter out empty strings at start of py_args
# (causes issues with python cli)
- *(arg for i, arg in enumerate(py_args) if (arg or i > 0)),
+ *iter_lstrip(py_args),
]
# Write files if any
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index dca0a8f..da2afea 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -361,11 +361,18 @@ class NsJailTests(unittest.TestCase):
self.assertEqual(result.args[end - len(args) : end], args)
def test_py_args(self):
- args = ["-m", "timeit"]
- result = self.nsjail.python3(args)
-
- self.assertEqual(result.returncode, 0)
- self.assertEqual(result.args[-2:], args)
+ expected = ["-m", "timeit"]
+ args = [
+ ["", "-m", "timeit"],
+ ["", "", "-m", "timeit"],
+ ["", "", "", "-m", "timeit"],
+ ]
+ # Leading empty strings should be removed
+ for case in args:
+ with self.subTest(args=args):
+ result = self.nsjail.python3(case)
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(result.args[-2:], expected)
class NsJailArgsTests(unittest.TestCase):