aboutsummaryrefslogtreecommitdiffstats
path: root/tests/snekio
diff options
context:
space:
mode:
Diffstat (limited to 'tests/snekio')
-rw-r--r--tests/snekio/__init__.py0
-rw-r--r--tests/snekio/test_filesystem.py143
-rw-r--r--tests/snekio/test_memfs.py65
-rw-r--r--tests/snekio/test_snekio.py58
4 files changed, 266 insertions, 0 deletions
diff --git a/tests/snekio/__init__.py b/tests/snekio/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/snekio/__init__.py
diff --git a/tests/snekio/test_filesystem.py b/tests/snekio/test_filesystem.py
new file mode 100644
index 0000000..9f6b76d
--- /dev/null
+++ b/tests/snekio/test_filesystem.py
@@ -0,0 +1,143 @@
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager, suppress
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+from uuid import uuid4
+
+from snekbox.snekio.filesystem import UnmountFlags, mount, unmount
+
+
+class LibMountTests(TestCase):
+ temp_dir: TemporaryDirectory
+
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir = TemporaryDirectory(prefix="snekbox_tests")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.temp_dir.cleanup()
+
+ @contextmanager
+ def get_mount(self):
+ """Yield a valid mount point and unmount after context."""
+ path = Path(self.temp_dir.name, str(uuid4()))
+ path.mkdir()
+ try:
+ mount(source="", target=path, fs="tmpfs")
+ yield path
+ finally:
+ with suppress(OSError):
+ unmount(path)
+
+ def test_mount(self):
+ """Test normal mounting."""
+ with self.get_mount() as path:
+ self.assertTrue(path.is_mount())
+ self.assertTrue(path.exists())
+ self.assertFalse(path.is_mount())
+ # Unmounting should not remove the original folder
+ self.assertTrue(path.exists())
+
+ def test_mount_errors(self):
+ """Test invalid mount errors."""
+ cases = [
+ (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"),
+ (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"),
+ (
+ dict(source="", target=self.temp_dir.name, fs="tmpfs", invalid_opt="?"),
+ OSError,
+ "Invalid argument",
+ ),
+ ]
+ for case, err, msg in cases:
+ with self.subTest(case=case):
+ with self.assertRaises(err) as cm:
+ mount(**case)
+ self.assertIn(msg, str(cm.exception))
+
+ def test_mount_duplicate(self):
+ """Test attempted mount after mounted."""
+ path = Path(self.temp_dir.name, str(uuid4()))
+ path.mkdir()
+ try:
+ mount(source="", target=path, fs="tmpfs")
+ with self.assertRaises(OSError) as cm:
+ mount(source="", target=path, fs="tmpfs")
+ self.assertIn("already a mount point", str(cm.exception))
+ finally:
+ unmount(target=path)
+
+ def test_unmount_flags(self):
+ """Test unmount flags."""
+ flags = [
+ UnmountFlags.MNT_FORCE,
+ UnmountFlags.MNT_DETACH,
+ UnmountFlags.UMOUNT_NOFOLLOW,
+ ]
+ for flag in flags:
+ with self.subTest(flag=flag), self.get_mount() as path:
+ self.assertTrue(path.is_mount())
+ unmount(path, flag)
+ self.assertFalse(path.is_mount())
+
+ def test_unmount_flags_expire(self):
+ """Test unmount MNT_EXPIRE behavior."""
+ with self.get_mount() as path:
+ with self.assertRaises(BlockingIOError):
+ unmount(path, UnmountFlags.MNT_EXPIRE)
+
+ def test_unmount_errors(self):
+ """Test invalid unmount errors."""
+ cases = [
+ (dict(target="not/exist"), OSError, "is not a mount point"),
+ (dict(target=Path("not/exist")), OSError, "is not a mount point"),
+ ]
+ for case, err, msg in cases:
+ with self.subTest(case=case):
+ with self.assertRaises(err) as cm:
+ unmount(**case)
+ self.assertIn(msg, str(cm.exception))
+
+ def test_unmount_invalid_args(self):
+ """Test invalid unmount invalid flag."""
+ with self.get_mount() as path:
+ with self.assertRaises(OSError) as cm:
+ unmount(path, 251)
+ self.assertIn("Invalid argument", str(cm.exception))
+
+ def test_threading(self):
+ """Test concurrent mounting works in multi-thread environments."""
+ paths = [Path(self.temp_dir.name, str(uuid4())) for _ in range(16)]
+
+ for path in paths:
+ path.mkdir()
+ self.assertFalse(path.is_mount())
+
+ try:
+ with ThreadPoolExecutor() as pool:
+ res = list(
+ pool.map(
+ mount,
+ [""] * len(paths),
+ paths,
+ ["tmpfs"] * len(paths),
+ )
+ )
+ self.assertEqual(len(res), len(paths))
+
+ for path in paths:
+ with self.subTest(path=path):
+ self.assertTrue(path.is_mount())
+
+ unmounts = list(pool.map(unmount, paths))
+ self.assertEqual(len(unmounts), len(paths))
+
+ for path in paths:
+ with self.subTest(path=path):
+ self.assertFalse(path.is_mount())
+ finally:
+ with suppress(OSError):
+ for path in paths:
+ unmount(path)
diff --git a/tests/snekio/test_memfs.py b/tests/snekio/test_memfs.py
new file mode 100644
index 0000000..cbe2fe4
--- /dev/null
+++ b/tests/snekio/test_memfs.py
@@ -0,0 +1,65 @@
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import ExitStack
+from unittest import TestCase, mock
+from uuid import uuid4
+
+from snekbox.snekio import MemFS
+
+UUID_TEST = uuid4()
+
+
+class MemFSTests(TestCase):
+ def setUp(self):
+ super().setUp()
+ self.logger = logging.getLogger("snekbox.snekio.memfs")
+ self.logger.setLevel(logging.WARNING)
+
+ @mock.patch("snekbox.snekio.memfs.uuid4", lambda: UUID_TEST)
+ def test_assignment_thread_safe(self):
+ """Test concurrent mounting works in multi-thread environments."""
+ # Concurrently create MemFS in threads, check only 1 can be created
+ # Others should result in RuntimeError
+ with ExitStack() as stack:
+ with ThreadPoolExecutor() as executor:
+ memfs: MemFS | None = None
+ # Each future uses enter_context to ensure __exit__ on test exception
+ futures = [
+ executor.submit(lambda: stack.enter_context(MemFS(10))) for _ in range(8)
+ ]
+ for future in futures:
+ # We should have exactly one result and all others RuntimeErrors
+ if err := future.exception():
+ self.assertIsInstance(err, RuntimeError)
+ else:
+ self.assertIsNone(memfs)
+ memfs = future.result()
+
+ # Original memfs should still exist afterwards
+ self.assertIsInstance(memfs, MemFS)
+ self.assertTrue(memfs.path.is_mount())
+
+ def test_cleanup(self):
+ """Test explicit cleanup."""
+ memfs = MemFS(10)
+ path = memfs.path
+ self.assertTrue(path.is_mount())
+ memfs.cleanup()
+ self.assertFalse(path.exists())
+
+ def test_context_cleanup(self):
+ """Context __exit__ should trigger cleanup."""
+ with MemFS(10) as memfs:
+ path = memfs.path
+ self.assertTrue(path.is_mount())
+ self.assertFalse(path.exists())
+
+ def test_implicit_cleanup(self):
+ """Test implicit _cleanup triggered by GC."""
+ memfs = MemFS(10)
+ path = memfs.path
+ self.assertTrue(path.is_mount())
+ # Catch the warning about implicit cleanup
+ with self.assertWarns(ResourceWarning):
+ del memfs
+ self.assertFalse(path.exists())
diff --git a/tests/snekio/test_snekio.py b/tests/snekio/test_snekio.py
new file mode 100644
index 0000000..8f04429
--- /dev/null
+++ b/tests/snekio/test_snekio.py
@@ -0,0 +1,58 @@
+from unittest import TestCase
+
+from snekbox import snekio
+from snekbox.snekio import FileAttachment, IllegalPathError, ParsingError
+
+
+class SnekIOTests(TestCase):
+ def test_safe_path(self) -> None:
+ cases = [
+ ("", ""),
+ ("foo", "foo"),
+ ("foo/bar", "foo/bar"),
+ ("foo/bar.ext", "foo/bar.ext"),
+ ]
+
+ for path, expected in cases:
+ self.assertEqual(snekio.safe_path(path), expected)
+
+ def test_safe_path_raise(self):
+ cases = [
+ ("../foo", IllegalPathError, "File path '../foo' may not traverse beyond root"),
+ ("/foo", IllegalPathError, "File path '/foo' must be relative"),
+ ]
+
+ for path, error, msg in cases:
+ with self.assertRaises(error) as cm:
+ snekio.safe_path(path)
+ self.assertEqual(str(cm.exception), msg)
+
+ def test_file_from_dict(self):
+ cases = [
+ ({"path": "foo", "content": ""}, FileAttachment("foo", b"")),
+ ({"path": "foo"}, FileAttachment("foo", b"")),
+ ({"path": "foo", "content": "Zm9v"}, FileAttachment("foo", b"foo")),
+ ({"path": "foo/bar.ext", "content": "Zm9v"}, FileAttachment("foo/bar.ext", b"foo")),
+ ]
+
+ for data, expected in cases:
+ self.assertEqual(FileAttachment.from_dict(data), expected)
+
+ def test_file_from_dict_error(self):
+ cases = [
+ (
+ {"path": "foo", "content": "9"},
+ ParsingError,
+ "Invalid base64 encoding for file 'foo'",
+ ),
+ (
+ {"path": "var/a.txt", "content": "1="},
+ ParsingError,
+ "Invalid base64 encoding for file 'var/a.txt'",
+ ),
+ ]
+
+ for data, error, msg in cases:
+ with self.assertRaises(error) as cm:
+ FileAttachment.from_dict(data)
+ self.assertEqual(str(cm.exception), msg)