about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- requirements/pip-tools.in 4
-rw-r--r-- requirements/pip-tools.pip 2
-rw-r--r-- snekbox/memfs.py 18
-rw-r--r-- tests/test_nsjail.py 59
4 files changed, 76 insertions, 7 deletions
diff --git a/requirements/pip-tools.in b/requirements/pip-tools.in
index e459df9..29d8d31 100644
--- a/requirements/pip-tools.in
+++ b/requirements/pip-tools.in
@@ -2,5 +2,5 @@
-c lint.pip
-c requirements.pip
-# Minimum version which supports pip>=22.1
-pip-tools>=6.6.1
+# Minimum version which supports pip>=23.2
+pip-tools>=7.0.0
diff --git a/requirements/pip-tools.pip b/requirements/pip-tools.pip
index d87f3d6..4793c2d 100644
--- a/requirements/pip-tools.pip
+++ b/requirements/pip-tools.pip
@@ -10,7 +10,7 @@ click==8.1.3
# via pip-tools
packaging==23.0
# via build
-pip-tools==6.12.3
+pip-tools==7.3.0
# via -r requirements/pip-tools.in
pyproject-hooks==1.0.0
# via build
diff --git a/snekbox/memfs.py b/snekbox/memfs.py
index 991766b..40b57c4 100644
--- a/snekbox/memfs.py
+++ b/snekbox/memfs.py
@@ -144,6 +144,7 @@ class MemFS:
"""
start_time = time.monotonic()
count = 0
+ total_size = 0
files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
for file in (Path(self.output, f) for f in files):
if timeout and (time.monotonic() - start_time) > timeout:
@@ -152,10 +153,15 @@ class MemFS:
if not file.is_file():
continue
+ # file.is_file allows file to be a regular file OR a symlink pointing to a regular file.
+ # It is important that we follow symlinks here, so when we check st_size later it is the
+ # size of the underlying file rather than of the symlink.
+ stat = file.stat(follow_symlinks=True)
+
if exclude_files and (orig_time := exclude_files.get(file)):
- new_time = file.stat().st_mtime
+ new_time = stat.st_mtime
log.info(f"Checking {file.name} ({orig_time=}, {new_time=})")
- if file.stat().st_mtime == orig_time:
+ if stat.st_mtime == orig_time:
log.info(f"Skipping {file.name!r} as it has not been modified")
continue
@@ -163,6 +169,14 @@ class MemFS:
log.info(f"Max attachments {limit} reached, skipping remaining files")
break
+ # Due to sparse files and links the total size could end up being greater
+ # than the size limit of the tmpfs. Limit the total size to be read to
+ # prevent high memory usage / OOM when reading files.
+ total_size += stat.st_size
+ if total_size > self.instance_size:
+ log.info(f"Max file size {self.instance_size} reached, skipping remaining files")
+ break
+
count += 1
log.info(f"Found valid file for upload {file.name!r}")
yield FileAttachment.from_path(file, relative_to=self.output)
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index c701d3a..5b06534 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -218,8 +218,9 @@ class NsJailTests(unittest.TestCase):
os.symlink("file", f"file{i}")
"""
).strip()
-
- nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1)
+ # A value higher than the actual memory needed is used to avoid the limit
+ # on total file size being reached before the timeout when reading.
+ nsjail = NsJail(memfs_instance_size=512 * Size.MiB, files_timeout=1)
result = nsjail.python3(["-c", code])
self.assertEqual(result.returncode, None)
self.assertEqual(
@@ -250,6 +251,60 @@ class NsJailTests(unittest.TestCase):
)
self.assertEqual(result.stderr, None)
+ def test_file_parsing_size_limit_sparse_files(self):
+ tmpfs_size = 8 * Size.MiB
+ code = dedent(
+ f"""
+ import os
+ with open("test.txt", "w") as f:
+ os.truncate(f.fileno(), {tmpfs_size // 2 + 1})
+
+ with open("test2.txt", "w") as f:
+ os.truncate(f.fileno(), {tmpfs_size // 2 + 1})
+ """
+ )
+ nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(len(result.files), 1)
+
+ def test_file_parsing_size_limit_sparse_files_large(self):
+ tmpfs_size = 8 * Size.MiB
+ code = dedent(
+ f"""
+ import os
+ with open("test.txt", "w") as f:
+ # Use a very large value to ensure the test fails if the
+ # file is read, even if it would have been discarded later.
+ os.truncate(f.fileno(), {1024 * Size.TiB})
+ """
+ )
+ nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(len(result.files), 0)
+
+ def test_file_parsing_size_limit_symlinks(self):
+ tmpfs_size = 8 * Size.MiB
+ code = dedent(
+ f"""
+ import os
+ data = "a" * 1024
+ size = {tmpfs_size // 8}
+
+ with open("file", "w") as f:
+ for _ in range(size // 1024):
+ f.write(data)
+
+ for i in range(20):
+ os.symlink("file", f"file{{i}}")
+ """
+ )
+ nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5)
+ result = nsjail.python3(["-c", code])
+ self.assertEqual(result.returncode, 0)
+ self.assertEqual(len(result.files), 8)
+
def test_file_write_error(self):
"""Test errors during file write."""
result = self.nsjail.python3(