aboutsummaryrefslogtreecommitdiffstats
path: root/tests/test_filesystem.py
diff options
context:
space:
mode:
authorGravatar ionite34 <[email protected]>2022-11-28 11:40:23 +0800
committerGravatar ionite34 <[email protected]>2022-11-28 11:40:23 +0800
commite72032fc3090d6e5d5d961a3aafaeec4cce787d7 (patch)
tree79063c704b9c03fda044088a9c50a83d01537a04 /tests/test_filesystem.py
parentAdd tests for invalid base64 file content (diff)
Rename `libmount` to `filesystem`
Diffstat (limited to 'tests/test_filesystem.py')
-rw-r--r--tests/test_filesystem.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py
new file mode 100644
index 0000000..b33322c
--- /dev/null
+++ b/tests/test_filesystem.py
@@ -0,0 +1,143 @@
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager, suppress
+from pathlib import Path
+from shutil import rmtree
+from unittest import TestCase
+from uuid import uuid4
+
+from snekbox.filesystem import UnmountFlags, mount, unmount
+
+
+class LibMountTests(TestCase):
+ temp_dir = Path("/tmp/snekbox")
+
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir.mkdir(exist_ok=True, parents=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ rmtree(cls.temp_dir, ignore_errors=True)
+
+ @contextmanager
+ def get_mount(self):
+ """Yield a valid mount point and unmount after context."""
+ path = self.temp_dir / str(uuid4())
+ path.mkdir()
+ try:
+ mount(source="", target=path, fs="tmpfs")
+ yield path
+ finally:
+ with suppress(OSError):
+ unmount(path)
+
+ def test_mount(self):
+ """Test normal mounting."""
+ with self.get_mount() as path:
+ self.assertTrue(path.is_mount())
+ self.assertTrue(path.exists())
+ self.assertFalse(path.is_mount())
+ # Unmounting should not remove the original folder
+ self.assertTrue(path.exists())
+
+ def test_mount_errors(self):
+ """Test invalid mount errors."""
+ cases = [
+ (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"),
+ (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"),
+ (
+ dict(source="", target=self.temp_dir, fs="tmpfs", invalid_opt="?"),
+ OSError,
+ "Invalid argument",
+ ),
+ ]
+ for case, err, msg in cases:
+ with self.subTest(case=case):
+ with self.assertRaises(err) as cm:
+ mount(**case)
+ self.assertIn(msg, str(cm.exception))
+
+ def test_mount_duplicate(self):
+ """Test attempted mount after mounted."""
+ path = self.temp_dir / str(uuid4())
+ path.mkdir()
+ try:
+ mount(source="", target=path, fs="tmpfs")
+ with self.assertRaises(OSError) as cm:
+ mount(source="", target=path, fs="tmpfs")
+ self.assertIn("already a mount point", str(cm.exception))
+ finally:
+ unmount(target=path)
+
+ def test_unmount_flags(self):
+ """Test unmount flags."""
+ flags = [
+ UnmountFlags.MNT_FORCE,
+ UnmountFlags.MNT_DETACH,
+ UnmountFlags.UMOUNT_NOFOLLOW,
+ ]
+ for flag in flags:
+ with self.subTest(flag=flag), self.get_mount() as path:
+ self.assertTrue(path.is_mount())
+ unmount(path, flag)
+ self.assertFalse(path.is_mount())
+
+ def test_unmount_flags_expire(self):
+ """Test unmount MNT_EXPIRE behavior."""
+ with self.get_mount() as path:
+ with self.assertRaises(BlockingIOError):
+ unmount(path, UnmountFlags.MNT_EXPIRE)
+
+ def test_unmount_errors(self):
+ """Test invalid unmount errors."""
+ cases = [
+ (dict(target="not/exist"), OSError, "is not a mount point"),
+ (dict(target=Path("not/exist")), OSError, "is not a mount point"),
+ ]
+ for case, err, msg in cases:
+ with self.subTest(case=case):
+ with self.assertRaises(err) as cm:
+ unmount(**case)
+ self.assertIn(msg, str(cm.exception))
+
+ def test_unmount_invalid_args(self):
+ """Test invalid unmount invalid flag."""
+ with self.get_mount() as path:
+ with self.assertRaises(OSError) as cm:
+ unmount(path, 251)
+ self.assertIn("Invalid argument", str(cm.exception))
+
+ def test_threading(self):
+ """Test concurrent mounting works in multi-thread environments."""
+ paths = [self.temp_dir / str(uuid4()) for _ in range(16)]
+
+ for path in paths:
+ path.mkdir()
+ self.assertFalse(path.is_mount())
+
+ try:
+ with ThreadPoolExecutor() as pool:
+ res = list(
+ pool.map(
+ mount,
+ [""] * len(paths),
+ paths,
+ ["tmpfs"] * len(paths),
+ )
+ )
+ self.assertEqual(len(res), len(paths))
+
+ for path in paths:
+ with self.subTest(path=path):
+ self.assertTrue(path.is_mount())
+
+ unmounts = list(pool.map(unmount, paths))
+ self.assertEqual(len(unmounts), len(paths))
+
+ for path in paths:
+ with self.subTest(path=path):
+ self.assertFalse(path.is_mount())
+ finally:
+ with suppress(OSError):
+ for path in paths:
+ unmount(path)