aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ionite34 <[email protected]>2022-11-28 10:33:07 +0800
committerGravatar ionite34 <[email protected]>2022-11-28 10:33:07 +0800
commit407e6b2079ea9f73f52e4972147620d765d96349 (patch)
treee8cd821ad27212178337611df05bba30e65160ba
parentAdd assertions for test_unmount_flags (diff)
Refactor FileAttachment as non generic
-rw-r--r--snekbox/snekio.py24
-rw-r--r--tests/test_nsjail.py8
2 files changed, 10 insertions, 22 deletions
diff --git a/snekbox/snekio.py b/snekbox/snekio.py
index 5023a69..9321eb2 100644
--- a/snekbox/snekio.py
+++ b/snekbox/snekio.py
@@ -5,9 +5,6 @@ from base64 import b64decode, b64encode
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
-from typing import Generic, TypeVar
-
-T = TypeVar("T", str, bytes)
def safe_path(path: str) -> str:
@@ -44,21 +41,21 @@ class IllegalPathError(ParsingError):
@dataclass
-class FileAttachment(Generic[T]):
+class FileAttachment:
"""A file attachment."""
path: str
- content: T
+ content: bytes
@classmethod
- def from_dict(cls, data: dict[str, str]) -> FileAttachment[bytes]:
+ def from_dict(cls, data: dict[str, str]) -> FileAttachment:
"""Convert a dict to an attachment."""
path = safe_path(data["path"])
content = b64decode(data.get("content", ""))
return cls(path, content)
@classmethod
- def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment[bytes]:
+ def from_path(cls, file: Path, relative_to: Path | None = None) -> FileAttachment:
"""
Create an attachment from a file path.
@@ -74,26 +71,17 @@ class FileAttachment(Generic[T]):
"""Size of the attachment."""
return len(self.content)
- def as_bytes(self) -> bytes:
- """Return the attachment as bytes."""
- if isinstance(self.content, bytes):
- return self.content
- return self.content.encode("utf-8")
-
def save_to(self, directory: Path | str) -> None:
"""Write the attachment to a file in `directory`."""
file = Path(directory, self.path)
# Create directories if they don't exist
file.parent.mkdir(parents=True, exist_ok=True)
- if isinstance(self.content, str):
- file.write_text(self.content, encoding="utf-8")
- else:
- file.write_bytes(self.content)
+ file.write_bytes(self.content)
@cached_property
def json(self) -> dict[str, str]:
"""Convert the attachment to a dict."""
- content = b64encode(self.as_bytes()).decode("ascii")
+ content = b64encode(self.content).decode("ascii")
return {
"path": self.path,
"size": self.size,
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index ccbca56..839d3ec 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -26,7 +26,7 @@ class NsJailTests(unittest.TestCase):
return self.nsjail.python3(["-c", code])
def eval_file(self, code: str, name: str = "test.py", **kwargs):
- file = FileAttachment(name, code)
+ file = FileAttachment(name, code.encode())
return self.nsjail.python3([name], [file], **kwargs)
def test_print_returns_0(self):
@@ -59,8 +59,8 @@ class NsJailTests(unittest.TestCase):
def test_multi_files(self):
files = [
- FileAttachment("main.py", "import lib; print(lib.x)"),
- FileAttachment("lib.py", "x = 'hello'"),
+ FileAttachment("main.py", "import lib; print(lib.x)".encode()),
+ FileAttachment("lib.py", "x = 'hello'".encode()),
]
result = self.nsjail.python3(["main.py"], files)
@@ -209,7 +209,7 @@ class NsJailTests(unittest.TestCase):
def test_file_write_error(self):
"""Test errors during file write."""
- result = self.nsjail.python3([""], [FileAttachment("output", "hello")])
+ result = self.nsjail.python3([""], [FileAttachment("output", "hello".encode())])
self.assertEqual(result.returncode, None)
self.assertEqual(result.stdout, "IsADirectoryError: Failed to create file 'output'.")