diff options
author | 2022-11-28 10:33:07 +0800 | |
---|---|---|
committer | 2022-11-28 10:33:07 +0800 | |
commit | 407e6b2079ea9f73f52e4972147620d765d96349 (patch) | |
tree | e8cd821ad27212178337611df05bba30e65160ba | |
parent | Add assertions for test_unmount_flags (diff) |
Refactor FileAttachment as non generic
-rw-r--r-- | snekbox/snekio.py | 24 | ||||
-rw-r--r-- | tests/test_nsjail.py | 8 |
2 files changed, 10 insertions, 22 deletions
diff --git a/snekbox/snekio.py b/snekbox/snekio.py index 5023a69..9321eb2 100644 --- a/snekbox/snekio.py +++ b/snekbox/snekio.py @@ -5,9 +5,6 @@ from base64 import b64decode, b64encode from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Generic, TypeVar - -T = TypeVar("T", str, bytes) def safe_path(path: str) -> str: @@ -44,21 +41,21 @@ class IllegalPathError(ParsingError): @dataclass -class FileAttachment(Generic[T]): +class FileAttachment: """A file attachment.""" path: str - content: T + content: bytes @classmethod - def from_dict(cls, data: dict[str, str]) -> FileAttachment[bytes]: + def from_dict(cls, data: dict[str, str]) -> FileAttachment: """Convert a dict to an attachment.""" path = safe_path(data["path"]) content = b64decode(data.get("content", "")) return cls(path, content) @classmethod - def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment[bytes]: + def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment: """ Create an attachment from a file path. @@ -74,26 +71,17 @@ class FileAttachment(Generic[T]): """Size of the attachment.""" return len(self.content) - def as_bytes(self) -> bytes: - """Return the attachment as bytes.""" - if isinstance(self.content, bytes): - return self.content - return self.content.encode("utf-8") - def save_to(self, directory: Path | str) -> None: """Write the attachment to a file in `directory`.""" file = Path(directory, self.path) # Create directories if they don't exist file.parent.mkdir(parents=True, exist_ok=True) - if isinstance(self.content, str): - file.write_text(self.content, encoding="utf-8") - else: - file.write_bytes(self.content) + file.write_bytes(self.content) @cached_property def json(self) -> dict[str, str]: """Convert the attachment to a dict.""" - content = b64encode(self.as_bytes()).decode("ascii") + content = b64encode(self.content).decode("ascii") return { "path": self.path, "size": self.size, diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index ccbca56..839d3ec 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -26,7 +26,7 @@ class NsJailTests(unittest.TestCase): return self.nsjail.python3(["-c", code]) def eval_file(self, code: str, name: str = "test.py", **kwargs): - file = FileAttachment(name, code) + file = FileAttachment(name, code.encode()) return self.nsjail.python3([name], [file], **kwargs) def test_print_returns_0(self): @@ -59,8 +59,8 @@ class NsJailTests(unittest.TestCase): def test_multi_files(self): files = [ - FileAttachment("main.py", "import lib; print(lib.x)"), - FileAttachment("lib.py", "x = 'hello'"), + FileAttachment("main.py", "import lib; print(lib.x)".encode()), + FileAttachment("lib.py", "x = 'hello'".encode()), ] result = self.nsjail.python3(["main.py"], files) @@ -209,7 +209,7 @@ class NsJailTests(unittest.TestCase): def test_file_write_error(self): """Test errors during file write.""" - result = self.nsjail.python3([""], [FileAttachment("output", "hello")]) + result = self.nsjail.python3([""], [FileAttachment("output", "hello".encode())]) self.assertEqual(result.returncode, None) self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'output'.") |