aboutsummaryrefslogtreecommitdiffstats
path: root/tests/test_memfs.py
blob: 8050562fa275f8b0e12bdc17c2ff42cb80933ad2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging
from concurrent.futures import ThreadPoolExecutor
from operator import attrgetter
from unittest import TestCase, mock
from uuid import uuid4

from snekbox.memfs import MemFS

# Fixed UUID that test_assignment_thread_safe patches in as the return value of
# snekbox.memfs.uuid4, so every MemFS created during that test draws the same
# UUID (presumably forcing a name collision, which is why only one is expected
# to mount successfully — confirm against snekbox.memfs).
UUID_TEST = uuid4()


def get_memfs_with_context():
    """Construct ``MemFS(10)``, enter it, and return the value of ``__enter__``.

    Intended as a thread-pool task: an exception raised by ``__enter__`` propagates
    to the caller (and so into the submitting Future). The instance is never exited
    here — whoever receives it is responsible for calling ``__exit__``.
    """
    fs = MemFS(10)
    entered = fs.__enter__()
    return entered


class MemFSTests(TestCase):
    """Tests for the MemFS context manager."""

    def setUp(self):
        super().setUp()
        self.logger = logging.getLogger("snekbox.memfs")
        # setLevel mutates a process-global logger; restore the previous level
        # after each test so the change does not leak into other tests.
        self.addCleanup(self.logger.setLevel, self.logger.level)
        self.logger.setLevel(logging.WARNING)

    @mock.patch("snekbox.memfs.uuid4", lambda: UUID_TEST)
    def test_assignment_thread_safe(self):
        """Test concurrent mounting works in multi-thread environments."""
        # Concurrently create MemFS in threads, check only 1 can be created
        # Others should result in RuntimeError
        with ThreadPoolExecutor() as pool:
            futures = [pool.submit(get_memfs_with_context) for _ in range(8)]
        # Leaving the `with` block shuts the pool down and waits for every future.

        errors = [future.exception() for future in futures]
        mounted: list[MemFS] = [
            future.result() for future, err in zip(futures, errors) if err is None
        ]
        # get_memfs_with_context() enters each MemFS without ever exiting it.
        # Register teardown for every successfully entered instance *before*
        # asserting anything, so a failing assertion cannot leak them.
        for fs in mounted:
            self.addCleanup(fs.__exit__, None, None, None)

        # We should have exactly one result and all others RuntimeErrors
        self.assertEqual(len(mounted), 1)
        for err in errors:
            if err is not None:
                self.assertIsInstance(err, RuntimeError)

        # Original memfs should still exist afterwards
        memfs = mounted[0]
        self.assertIsInstance(memfs, MemFS)
        self.assertTrue(memfs.path.exists())

    def test_no_context_error(self):
        """Accessing MemFS attributes before __enter__ raises RuntimeError."""
        # Each case touches state that only exists once __enter__ has run.
        cases = [
            attrgetter("path"),
            attrgetter("name"),
            attrgetter("home"),
            attrgetter("output"),
            lambda fs: fs.mkdir(""),
            lambda fs: list(fs.files(1)),
        ]

        memfs = MemFS(10)
        for case in cases:
            with self.subTest(case=case), self.assertRaises(RuntimeError):
                case(memfs)