diff options
Diffstat (limited to 'tests/snekio')
| -rw-r--r-- | tests/snekio/__init__.py | 0 | ||||
| -rw-r--r-- | tests/snekio/test_filesystem.py | 143 | ||||
| -rw-r--r-- | tests/snekio/test_memfs.py | 65 | ||||
| -rw-r--r-- | tests/snekio/test_snekio.py | 58 | 
4 files changed, 266 insertions, 0 deletions
diff --git a/tests/snekio/__init__.py b/tests/snekio/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/snekio/__init__.py diff --git a/tests/snekio/test_filesystem.py b/tests/snekio/test_filesystem.py new file mode 100644 index 0000000..9f6b76d --- /dev/null +++ b/tests/snekio/test_filesystem.py @@ -0,0 +1,143 @@ +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager, suppress +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import TestCase +from uuid import uuid4 + +from snekbox.snekio.filesystem import UnmountFlags, mount, unmount + + +class LibMountTests(TestCase): +    temp_dir: TemporaryDirectory + +    @classmethod +    def setUpClass(cls): +        cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests") + +    @classmethod +    def tearDownClass(cls): +        cls.temp_dir.cleanup() + +    @contextmanager +    def get_mount(self): +        """Yield a valid mount point and unmount after context.""" +        path = Path(self.temp_dir.name, str(uuid4())) +        path.mkdir() +        try: +            mount(source="", target=path, fs="tmpfs") +            yield path +        finally: +            with suppress(OSError): +                unmount(path) + +    def test_mount(self): +        """Test normal mounting.""" +        with self.get_mount() as path: +            self.assertTrue(path.is_mount()) +            self.assertTrue(path.exists()) +        self.assertFalse(path.is_mount()) +        # Unmounting should not remove the original folder +        self.assertTrue(path.exists()) + +    def test_mount_errors(self): +        """Test invalid mount errors.""" +        cases = [ +            (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"), +            (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"), +            ( +                dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"), +                OSError, +                "Invalid argument", +            ), +        ] +        for case, err, msg in cases: +            with self.subTest(case=case): +                with self.assertRaises(err) as cm: +                    mount(**case) +                self.assertIn(msg, str(cm.exception)) + +    def test_mount_duplicate(self): +        """Test attempted mount after mounted.""" +        path = Path(self.temp_dir.name, str(uuid4())) +        path.mkdir() +        try: +            mount(source="", target=path, fs="tmpfs") +            with self.assertRaises(OSError) as cm: +                mount(source="", target=path, fs="tmpfs") +            self.assertIn("already a mount point", str(cm.exception)) +        finally: +            unmount(target=path) + +    def test_unmount_flags(self): +        """Test unmount flags.""" +        flags = [ +            UnmountFlags.MNT_FORCE, +            UnmountFlags.MNT_DETACH, +            UnmountFlags.UMOUNT_NOFOLLOW, +        ] +        for flag in flags: +            with self.subTest(flag=flag), self.get_mount() as path: +                self.assertTrue(path.is_mount()) +                unmount(path, flag) +                self.assertFalse(path.is_mount()) + +    def test_unmount_flags_expire(self): +        """Test unmount MNT_EXPIRE behavior.""" +        with self.get_mount() as path: +            with self.assertRaises(BlockingIOError): +                unmount(path, UnmountFlags.MNT_EXPIRE) + +    def test_unmount_errors(self): +        """Test invalid unmount errors.""" +        cases = [ +            (dict(target="not/exist"), OSError, "is not a mount point"), +            (dict(target=Path("not/exist")), OSError, "is not a mount point"), +        ] +        for case, err, msg in cases: +            with self.subTest(case=case): +                with self.assertRaises(err) as cm: +                    unmount(**case) +                self.assertIn(msg, str(cm.exception)) + +    def test_unmount_invalid_args(self): +        """Test invalid unmount invalid flag.""" +        with self.get_mount() as path: +            with self.assertRaises(OSError) as cm: +                unmount(path, 251) +            self.assertIn("Invalid argument", str(cm.exception)) + +    def test_threading(self): +        """Test concurrent mounting works in multi-thread environments.""" +        paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)] + +        for path in paths: +            path.mkdir() +            self.assertFalse(path.is_mount()) + +        try: +            with ThreadPoolExecutor() as pool: +                res = list( +                    pool.map( +                        mount, +                        [""] * len(paths), +                        paths, +                        ["tmpfs"] * len(paths), +                    ) +                ) +                self.assertEqual(len(res), len(paths)) + +                for path in paths: +                    with self.subTest(path=path): +                        self.assertTrue(path.is_mount()) + +                unmounts = list(pool.map(unmount, paths)) +                self.assertEqual(len(unmounts), len(paths)) + +                for path in paths: +                    with self.subTest(path=path): +                        self.assertFalse(path.is_mount()) +        finally: +            with suppress(OSError): +                for path in paths: +                    unmount(path) diff --git a/tests/snekio/test_memfs.py b/tests/snekio/test_memfs.py new file mode 100644 index 0000000..cbe2fe4 --- /dev/null +++ b/tests/snekio/test_memfs.py @@ -0,0 +1,65 @@ +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import ExitStack +from unittest import TestCase, mock +from uuid import uuid4 + +from snekbox.snekio import MemFS + +UUID_TEST = uuid4() + + +class MemFSTests(TestCase): +    def setUp(self): +        super().setUp() +        self.logger = logging.getLogger("snekbox.snekio.memfs") +        self.logger.setLevel(logging.WARNING) + +    @mock.patch("snekbox.snekio.memfs.uuid4", lambda: UUID_TEST) +    def test_assignment_thread_safe(self): +        """Test concurrent mounting works in multi-thread environments.""" +        # Concurrently create MemFS in threads, check only 1 can be created +        # Others should result in RuntimeError +        with ExitStack() as stack: +            with ThreadPoolExecutor() as executor: +                memfs: MemFS | None = None +                # Each future uses enter_context to ensure __exit__ on test exception +                futures = [ +                    executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8) +                ] +                for future in futures: +                    # We should have exactly one result and all others RuntimeErrors +                    if err := future.exception(): +                        self.assertIsInstance(err, RuntimeError) +                    else: +                        self.assertIsNone(memfs) +                        memfs = future.result() + +                # Original memfs should still exist afterwards +                self.assertIsInstance(memfs, MemFS) +                self.assertTrue(memfs.path.is_mount()) + +    def test_cleanup(self): +        """Test explicit cleanup.""" +        memfs = MemFS(10) +        path = memfs.path +        self.assertTrue(path.is_mount()) +        memfs.cleanup() +        self.assertFalse(path.exists()) + +    def test_context_cleanup(self): +        """Context __exit__ should trigger cleanup.""" +        with MemFS(10) as memfs: +            path = memfs.path +            self.assertTrue(path.is_mount()) +        self.assertFalse(path.exists()) + +    def test_implicit_cleanup(self): +        """Test implicit _cleanup triggered by GC.""" +        memfs = MemFS(10) +        path = memfs.path +        self.assertTrue(path.is_mount()) +        # Catch the warning about implicit cleanup +        with self.assertWarns(ResourceWarning): +            del memfs +        self.assertFalse(path.exists()) diff --git a/tests/snekio/test_snekio.py b/tests/snekio/test_snekio.py new file mode 100644 index 0000000..8f04429 --- /dev/null +++ b/tests/snekio/test_snekio.py @@ -0,0 +1,58 @@ +from unittest import TestCase + +from snekbox import snekio +from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError + + +class SnekIOTests(TestCase): +    def test_safe_path(self) -> None: +        cases = [ +            ("", ""), +            ("foo", "foo"), +            ("foo/bar", "foo/bar"), +            ("foo/bar.ext", "foo/bar.ext"), +        ] + +        for path, expected in cases: +            self.assertEqual(snekio.safe_path(path), expected) + +    def test_safe_path_raise(self): +        cases = [ +            ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"), +            ("/foo", IllegalPathError, "File path '/foo' must be relative"), +        ] + +        for path, error, msg in cases: +            with self.assertRaises(error) as cm: +                snekio.safe_path(path) +            self.assertEqual(str(cm.exception), msg) + +    def test_file_from_dict(self): +        cases = [ +            ({"path": "foo", "content": ""}, FileAttachment("foo", b"")), +            ({"path": "foo"}, FileAttachment("foo", b"")), +            ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")), +            ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")), +        ] + +        for data, expected in cases: +            self.assertEqual(FileAttachment.from_dict(data), expected) + +    def test_file_from_dict_error(self): +        cases = [ +            ( +                {"path": "foo", "content": "9"}, +                ParsingError, +                "Invalid base64 encoding for file 'foo'", +            ), +            ( +                {"path": "var/a.txt", "content": "1="}, +                ParsingError, +                "Invalid base64 encoding for file 'var/a.txt'", +            ), +        ] + +        for data, error, msg in cases: +            with self.assertRaises(error) as cm: +                FileAttachment.from_dict(data) +            self.assertEqual(str(cm.exception), msg)  |