diff options
| author | 2022-11-28 10:33:07 +0800 | |
|---|---|---|
| committer | 2022-11-28 10:33:07 +0800 | |
| commit | 407e6b2079ea9f73f52e4972147620d765d96349 (patch) | |
| tree | e8cd821ad27212178337611df05bba30e65160ba | |
| parent | Add assertions for test_unmount_flags (diff) | |
Refactor FileAttachment as non generic
| -rw-r--r-- | snekbox/snekio.py | 24 | ||||
| -rw-r--r-- | tests/test_nsjail.py | 8 | 
2 files changed, 10 insertions, 22 deletions
| diff --git a/snekbox/snekio.py b/snekbox/snekio.py index 5023a69..9321eb2 100644 --- a/snekbox/snekio.py +++ b/snekbox/snekio.py @@ -5,9 +5,6 @@ from base64 import b64decode, b64encode  from dataclasses import dataclass  from functools import cached_property  from pathlib import Path -from typing import Generic, TypeVar - -T = TypeVar("T", str, bytes)  def safe_path(path: str) -> str: @@ -44,21 +41,21 @@ class IllegalPathError(ParsingError):  @dataclass -class FileAttachment(Generic[T]): +class FileAttachment:      """A file attachment."""      path: str -    content: T +    content: bytes      @classmethod -    def from_dict(cls, data: dict[str, str]) -> FileAttachment[bytes]: +    def from_dict(cls, data: dict[str, str]) -> FileAttachment:          """Convert a dict to an attachment."""          path = safe_path(data["path"])          content = b64decode(data.get("content", ""))          return cls(path, content)      @classmethod -    def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment[bytes]: +    def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment:          """          Create an attachment from a file path. @@ -74,26 +71,17 @@ class FileAttachment(Generic[T]):          """Size of the attachment."""          return len(self.content) -    def as_bytes(self) -> bytes: -        """Return the attachment as bytes.""" -        if isinstance(self.content, bytes): -            return self.content -        return self.content.encode("utf-8") -      def save_to(self, directory: Path | str) -> None:          """Write the attachment to a file in `directory`."""          file = Path(directory, self.path)          # Create directories if they don't exist          file.parent.mkdir(parents=True, exist_ok=True) -        if isinstance(self.content, str): -            file.write_text(self.content, encoding="utf-8") -        else: -            file.write_bytes(self.content) +        file.write_bytes(self.content)      @cached_property      def json(self) -> dict[str, str]:          """Convert the attachment to a dict.""" -        content = b64encode(self.as_bytes()).decode("ascii") +        content = b64encode(self.content).decode("ascii")          return {              "path": self.path,              "size": self.size, diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index ccbca56..839d3ec 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -26,7 +26,7 @@ class NsJailTests(unittest.TestCase):          return self.nsjail.python3(["-c", code])      def eval_file(self, code: str, name: str = "test.py", **kwargs): -        file = FileAttachment(name, code) +        file = FileAttachment(name, code.encode())          return self.nsjail.python3([name], [file], **kwargs)      def test_print_returns_0(self): @@ -59,8 +59,8 @@ class NsJailTests(unittest.TestCase):      def test_multi_files(self):          files = [ -            FileAttachment("main.py", "import lib; print(lib.x)"), -            FileAttachment("lib.py", "x = 'hello'"), +            FileAttachment("main.py", "import lib; print(lib.x)".encode()), +            FileAttachment("lib.py", "x = 'hello'".encode()),          ]          result = self.nsjail.python3(["main.py"], files) @@ -209,7 +209,7 @@ class NsJailTests(unittest.TestCase):      def test_file_write_error(self):          """Test errors during file write.""" -        result = self.nsjail.python3([""], [FileAttachment("output", "hello")]) +        result = self.nsjail.python3([""], [FileAttachment("output", "hello".encode())])          self.assertEqual(result.returncode, None)          self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'output'.") | 
