diff options
| author | 2023-08-29 14:57:50 +0000 | |
|---|---|---|
| committer | 2023-08-29 14:58:10 +0000 | |
| commit | 16b1a13e206fa34bfc7af05363d5e78742e26e40 (patch) | |
| tree | 7bcb00ee73bdfbd76e3be72bcf7f97e228d0b1e7 | |
| parent | Use pip-tools version that supports newer pip versions (diff) | |
Limit total file size read from tmpfs to avoid high memory usage
Diffstat (limited to '')
| -rw-r--r-- | snekbox/memfs.py | 18 | ||||
| -rw-r--r-- | tests/test_nsjail.py | 59 | 
2 files changed, 73 insertions, 4 deletions
| diff --git a/snekbox/memfs.py b/snekbox/memfs.py index 991766b..40b57c4 100644 --- a/snekbox/memfs.py +++ b/snekbox/memfs.py @@ -144,6 +144,7 @@ class MemFS:          """          start_time = time.monotonic()          count = 0 +        total_size = 0          files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)          for file in (Path(self.output, f) for f in files):              if timeout and (time.monotonic() - start_time) > timeout: @@ -152,10 +153,15 @@ class MemFS:              if not file.is_file():                  continue +            # file.is_file allows file to be a regular file OR a symlink pointing to a regular file. +            # It is important that we follow symlinks here, so when we check st_size later it is the +            # size of the underlying file rather than of the symlink. +            stat = file.stat(follow_symlinks=True) +              if exclude_files and (orig_time := exclude_files.get(file)): -                new_time = file.stat().st_mtime +                new_time = stat.st_mtime                  log.info(f"Checking {file.name} ({orig_time=}, {new_time=})") -                if file.stat().st_mtime == orig_time: +                if stat.st_mtime == orig_time:                      log.info(f"Skipping {file.name!r} as it has not been modified")                      continue @@ -163,6 +169,14 @@ class MemFS:                  log.info(f"Max attachments {limit} reached, skipping remaining files")                  break +            # Due to sparse files and links the total size could end up being greater +            # than the size limit of the tmpfs. Limit the total size to be read to +            # prevent high memory usage / OOM when reading files. +            total_size += stat.st_size +            if total_size > self.instance_size: +                log.info(f"Max file size {self.instance_size} reached, skipping remaining files") +                break +              count += 1              log.info(f"Found valid file for upload {file.name!r}")              yield FileAttachment.from_path(file, relative_to=self.output) diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index c701d3a..5b06534 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -218,8 +218,9 @@ class NsJailTests(unittest.TestCase):                  os.symlink("file", f"file{i}")              """          ).strip() - -        nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1) +        # A value higher than the actual memory needed is used to avoid the limit +        # on total file size being reached before the timeout when reading. +        nsjail = NsJail(memfs_instance_size=512 * Size.MiB, files_timeout=1)          result = nsjail.python3(["-c", code])          self.assertEqual(result.returncode, None)          self.assertEqual( @@ -250,6 +251,60 @@ class NsJailTests(unittest.TestCase):          )          self.assertEqual(result.stderr, None) +    def test_file_parsing_size_limit_sparse_files(self): +        tmpfs_size = 8 * Size.MiB +        code = dedent( +            f""" +            import os +            with open("test.txt", "w") as f: +                os.truncate(f.fileno(), {tmpfs_size // 2 + 1}) + +            with open("test2.txt", "w") as f: +                os.truncate(f.fileno(), {tmpfs_size // 2 + 1}) +            """ +        ) +        nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5) +        result = nsjail.python3(["-c", code]) +        self.assertEqual(result.returncode, 0) +        self.assertEqual(len(result.files), 1) + +    def test_file_parsing_size_limit_sparse_files_large(self): +        tmpfs_size = 8 * Size.MiB +        code = dedent( +            f""" +            import os +            with open("test.txt", "w") as f: +                # Use a very large value to ensure the test fails if the +                # file is read even if would have been discarded later. +                os.truncate(f.fileno(), {1024 * Size.TiB}) +            """ +        ) +        nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5) +        result = nsjail.python3(["-c", code]) +        self.assertEqual(result.returncode, 0) +        self.assertEqual(len(result.files), 0) + +    def test_file_parsing_size_limit_symlinks(self): +        tmpfs_size = 8 * Size.MiB +        code = dedent( +            f""" +            import os +            data = "a" * 1024 +            size = {tmpfs_size // 8} + +            with open("file", "w") as f: +                for _ in range(size // 1024): +                    f.write(data) + +            for i in range(20): +                os.symlink("file", f"file{{i}}") +            """ +        ) +        nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5) +        result = nsjail.python3(["-c", code]) +        self.assertEqual(result.returncode, 0) +        self.assertEqual(len(result.files), 8) +      def test_file_write_error(self):          """Test errors during file write."""          result = self.nsjail.python3( | 
