diff options
| author | 2023-08-30 13:55:33 +0100 | |
|---|---|---|
| committer | 2023-08-30 13:55:33 +0100 | |
| commit | 2b5f8a0d1f795092678f87ded0823f1b16cea63b (patch) | |
| tree | d503649f9fd4481c45398736d5098b563072ff69 | |
| parent | Merge pull request #185 from python-discord/jb3/refactor-eval-deps (diff) | |
| parent | Merge branch 'main' into jb3/shared-mem (diff) | |
Merge pull request #113 from python-discord/jb3/shared-mem
Enable a limited tmpfs for shared memory
| -rw-r--r-- | config/snekbox.cfg | 8 | ||||
| -rw-r--r-- | tests/test_nsjail.py | 81 | 
2 files changed, 72 insertions, 17 deletions
| diff --git a/config/snekbox.cfg b/config/snekbox.cfg index 4e146ec..1bd2ab6 100644 --- a/config/snekbox.cfg +++ b/config/snekbox.cfg @@ -87,6 +87,14 @@ mount {      rw: false  } +mount { +    dst: "/dev/shm" +    fstype: "tmpfs" +    rw: true +    is_bind: false +    options: "size=40m" +} +  cgroup_mem_max: 52428800  cgroup_mem_swap_max: 0  cgroup_mem_mount: "/sys/fs/cgroup/memory" diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index e422de5..fe55290 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -23,6 +23,9 @@ class NsJailTests(unittest.TestCase):          self.logger = logging.getLogger("snekbox.nsjail")          self.logger.setLevel(logging.WARNING) +        # Hard-coded because it's non-trivial to parse the mount options. +        self.shm_mount_size = 40 * Size.MiB +      def eval_code(self, code: str):          return self.nsjail.python3(["-c", code]) @@ -125,6 +128,25 @@ class NsJailTests(unittest.TestCase):          self.assertIn("-9", exit_codes)          self.assertEqual(result.stderr, None) +    def test_multiprocessing_pool(self): +        # Validates that shm is working as expected +        code = dedent( +            """ +            from multiprocessing import Pool + +            def f(x): +                return x*x + +            with Pool(2) as p: +                print(p.map(f, [1, 2, 3])) +        """ +        ) + +        result = self.eval_file(code) + +        self.assertEqual(result.stdout, "[1, 4, 9]\n") +        self.assertEqual(result.returncode, 0) +      def test_read_only_file_system(self):          for path in ("/", "/etc", "/lib", "/lib64", "/snekbox", "/usr"):              with self.subTest(path=path): @@ -390,35 +412,60 @@ class NsJailTests(unittest.TestCase):              log.output,          ) -    def test_shm_and_tmp_not_mounted(self): -        for path in ("/dev/shm", "/run/shm", "/tmp"): -            with self.subTest(path=path): +    def test_tmp_not_mounted(self): +        code = dedent( +            """ +            with open('/tmp/test', 'wb') as file: +                file.write(bytes([255])) +        """ +        ).strip() + +        result = self.eval_file(code) +        self.assertEqual(result.returncode, 1) +        self.assertIn("No such file or directory", result.stdout) +        self.assertEqual(result.stderr, None) + +    def test_multiprocessing_shared_memory(self): +        cases = ( +            (self.shm_mount_size, self.shm_mount_size, 0), +            # Even if the shared memory object is larger than the mount, +            # writing data within the size of the mount should succeed. +            (self.shm_mount_size + 1, self.shm_mount_size, 0), +            (self.shm_mount_size + 1, self.shm_mount_size + 1, 135), +        ) + +        for shm_size, buffer_size, return_code in cases: +            with self.subTest(shm_size=shm_size, buffer_size=buffer_size): +                # Need enough memory for buffer and bytearray plus some overhead. +                mem_max = (buffer_size * 2) + (400 * Size.MiB)                  code = dedent(                      f""" -                    with open('{path}/test', 'wb') as file: -                        file.write(bytes([255])) -                    """ +                    from multiprocessing.shared_memory import SharedMemory + +                    shm = SharedMemory(create=True, size={shm_size}) +                    shm.buf[:{buffer_size}] = bytearray([1] * {buffer_size}) +                """                  ).strip() -                result = self.eval_file(code) -                self.assertEqual(result.returncode, 1) -                self.assertIn("No such file or directory", result.stdout) +                result = self.eval_file(code, nsjail_args=("--cgroup_mem_max", str(mem_max))) + +                self.assertEqual(result.returncode, return_code) +                self.assertEqual(result.stdout, "")                  self.assertEqual(result.stderr, None) -    def test_multiprocessing_shared_memory_disabled(self): +    def test_multiprocessing_shared_memory_mmap_limited(self): +        """The mmap call should be OOM trying to map a large & sparse shared memory object."""          code = dedent( -            """ +            f"""              from multiprocessing.shared_memory import SharedMemory -            try: -                SharedMemory('test', create=True, size=16) -            except FileExistsError: -                pass -            """ + +            SharedMemory(create=True, size={self.nsjail.config.cgroup_mem_max + Size.GiB}) +        """          ).strip()          result = self.eval_file(code)          self.assertEqual(result.returncode, 1) -        self.assertIn("Function not implemented", result.stdout) +        self.assertIn("[Errno 12] Cannot allocate memory", result.stdout)          self.assertEqual(result.stderr, None)      def test_numpy_import(self): | 
