from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, suppress
from pathlib import Path
from unittest import TestCase
from uuid import uuid4

from snekbox import libmount


class LibMountTests(TestCase):
    """Tests for the `snekbox.libmount` mount/unmount wrappers.

    These tests perform real tmpfs mounts inside a scratch directory, so they
    presumably need the privileges mount(2) requires — TODO confirm for CI.
    """

    def setUp(self):
        super().setUp()
        # Shared scratch area; each test mounts onto its own uuid4 sub-directory
        # so tests do not interfere with each other.
        self.temp_dir = Path("/tmp/snekbox-test")
        self.temp_dir.mkdir(exist_ok=True, parents=True)

    @contextmanager
    def get_mount(self):
        """Yield a valid mount point and unmount after context.

        A fresh directory is created under `self.temp_dir` and a tmpfs is
        mounted on it. On exit the mount is removed, but the directory itself
        is left in place. Unmount errors are ignored, e.g. when the test body
        already unmounted the path itself.
        """
        path = self.temp_dir / str(uuid4())
        path.mkdir()
        try:
            libmount.mount(source="", target=path, fs="tmpfs")
            yield path
        finally:
            with suppress(OSError):
                libmount.unmount(path)

    def test_mount(self):
        """Test normal mounting."""
        with self.get_mount() as path:
            self.assertTrue(path.is_mount())
            self.assertTrue(path.exists())
        self.assertFalse(path.is_mount())
        # Unmounting should not remove the original folder
        self.assertTrue(path.exists())

    def test_mount_errors(self):
        """Test invalid mount errors."""
        cases = [
            # Target directory does not exist.
            (dict(source="", target=str(uuid4()), fs="tmpfs"), OSError, "No such file"),
            # Relative target that does not exist.
            (dict(source=str(uuid4()), target="some/dir", fs="tmpfs"), OSError, "No such file"),
            # Unknown extra mount option should be rejected.
            (
                dict(source="", target=self.temp_dir, fs="tmpfs", invalid_opt="?"),
                OSError,
                "Invalid argument",
            ),
        ]
        for case, err, msg in cases:
            with self.subTest(case=case):
                with self.assertRaises(err) as cm:
                    libmount.mount(**case)
                self.assertIn(msg, str(cm.exception))

    def test_mount_duplicate(self):
        """Test attempted mount after mounted."""
        # Reuse the shared helper so cleanup is guaranteed, and so an unmount
        # failure cannot hide the real assertion error.
        with self.get_mount() as path:
            with self.assertRaises(OSError) as cm:
                libmount.mount(source="", target=path, fs="tmpfs")
            self.assertIn("already a mount point", str(cm.exception))

    def test_unmount_flags(self):
        """Test unmount flags."""
        flags = [
            libmount.UnmountFlags.MNT_FORCE,
            libmount.UnmountFlags.MNT_DETACH,
            libmount.UnmountFlags.UMOUNT_NOFOLLOW,
        ]
        for flag in flags:
            with self.subTest(flag=flag), self.get_mount() as path:
                libmount.unmount(path, flag)

    def test_unmount_flags_expire(self):
        """Test unmount MNT_EXPIRE behavior."""
        with self.get_mount() as path:
            # umount2(2) documents that the first MNT_EXPIRE call on an unused
            # mount only marks it as expired and fails with EAGAIN
            # (surfaced in Python as BlockingIOError).
            with self.assertRaises(BlockingIOError):
                libmount.unmount(path, libmount.UnmountFlags.MNT_EXPIRE)

    def test_unmount_errors(self):
        """Test invalid unmount errors."""
        cases = [
            (dict(target="not/exist"), OSError, "is not a mount point"),
            (dict(target=Path("not/exist")), OSError, "is not a mount point"),
        ]
        for case, err, msg in cases:
            with self.subTest(case=case):
                with self.assertRaises(err) as cm:
                    libmount.unmount(**case)
                self.assertIn(msg, str(cm.exception))

    def test_unmount_invalid_args(self):
        """Test invalid unmount invalid flag."""
        with self.get_mount() as path:
            with self.assertRaises(OSError) as cm:
                # 251 (0xFB) presumably sets bits that are not valid unmount
                # flags, so the call is expected to fail with EINVAL.
                libmount.unmount(path, 251)
            self.assertIn("Invalid argument", str(cm.exception))

    def test_threading(self):
        """Test concurrent mounting works in multi-thread environments."""
        paths = [self.temp_dir / str(uuid4()) for _ in range(16)]
        for path in paths:
            path.mkdir()
            self.assertFalse(path.is_mount())

        try:
            with ThreadPoolExecutor() as pool:
                # pool.map zips the iterables, so these lists supply the
                # (source, target, fs) positional arguments for each mount.
                res = list(
                    pool.map(
                        libmount.mount,
                        [""] * len(paths),
                        paths,
                        ["tmpfs"] * len(paths),
                    )
                )
                self.assertEqual(len(res), len(paths))

                for path in paths:
                    with self.subTest(path=path):
                        self.assertTrue(path.is_mount())

                unmounts = list(pool.map(libmount.unmount, paths))
                self.assertEqual(len(unmounts), len(paths))

                for path in paths:
                    with self.subTest(path=path):
                        self.assertFalse(path.is_mount())
        finally:
            # Suppress errors per path: one already-unmounted (or never
            # mounted) path must not stop cleanup of the remaining mounts.
            for path in paths:
                with suppress(OSError):
                    libmount.unmount(path)