diff options
-rw-r--r-- | snekbox/memfs.py | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/snekbox/memfs.py b/snekbox/memfs.py index 48a26a2..b15008d 100644 --- a/snekbox/memfs.py +++ b/snekbox/memfs.py @@ -1,34 +1,39 @@ """Memory filesystem for snekbox.""" +from __future__ import annotations import logging +import subprocess from functools import cache from pathlib import Path from shutil import rmtree from threading import BoundedSemaphore -from uuid import uuid4, uuid5 +from types import TracebackType +from typing import Type +from uuid import uuid4 -from typing_extensions import Self - -NAMESPACE = "com.snekbox" +MEMFS_SIZE = "2G" log = logging.getLogger(__name__) @cache -def shm_tempdir() -> Path: +def mem_tempdir() -> Path: """Return the snekbox namespace temporary directory.""" - shm = Path("/dev/shm") - if not shm.exists() or not shm.is_dir(): - raise RuntimeError("No /dev/shm found") + tmp = Path("/snekbox/memfs") + if not tmp.exists() or not tmp.is_dir(): + # Create `memfs` and mount it as a tmpfs + tmp.mkdir(parents=True, exist_ok=True) + tmp.chmod(0o777) + subprocess.check_call( + ["mount", "-t", "tmpfs", "-o", f"size={MEMFS_SIZE}", "tmpfs", str(tmp)] + ) - # Create a temporary directory in the snekbox namespace - tempdir = Path(shm, NAMESPACE) - tempdir.mkdir(exist_ok=True) - return tempdir + return tmp class MemoryTempDir: """A temporary directory using tmpfs.""" + assignment_lock = BoundedSemaphore(1) # Only one process can assign a tempdir at a time assigned_names: set[str] = set() # Pool of tempdir names in use @@ -43,20 +48,25 @@ class MemoryTempDir: return None return self.path.name - def __enter__(self) -> Self: + def __enter__(self) -> MemoryTempDir: # Generates a uuid tempdir with self.assignment_lock: for _ in range(10): name = str(uuid4()) if name not in self.assigned_names: - self.path = Path(shm_tempdir(), name) + self.path = Path(mem_tempdir(), name) self.path.mkdir() self.assigned_names.add(name) return self else: raise RuntimeError("Failed to generate a unique tempdir name in 10 attempts") - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.cleanup() def cleanup(self) -> None: |