diff options
author | 2022-11-20 21:34:54 -0500 | |
---|---|---|
committer | 2022-11-20 21:34:54 -0500 | |
commit | 69ff6809331bcf6097a71c6f7b1087fe344c8797 (patch) | |
tree | 468718291881d63a6284181ed23a03741320399c | |
parent | Add unit test for multiple python files (diff) |
Combined file handling to FileAttachment class
-rw-r--r-- | snekbox/api/resources/eval.py | 14 | ||||
-rw-r--r-- | snekbox/nsjail.py | 4 | ||||
-rw-r--r-- | snekbox/snekio.py | 77 | ||||
-rw-r--r-- | tests/test_nsjail.py | 8 |
4 files changed, 63 insertions, 40 deletions
diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py index 9244d7e..af843d2 100644 --- a/snekbox/api/resources/eval.py +++ b/snekbox/api/resources/eval.py @@ -9,7 +9,7 @@ from snekbox.nsjail import NsJail __all__ = ("EvalResource",) -from snekbox.snekio import EvalRequestFile, FileParsingError +from snekbox.snekio import FileAttachment, ParsingError log = logging.getLogger(__name__) @@ -32,7 +32,11 @@ class EvalResource: "type": "array", "items": { "type": "object", - "properties": {"name": {"type": "string"}, "content": {"type": "string"}}, + "properties": { + "name": {"type": "string"}, + "content-encoding": {"type": "string"}, + "content": {"type": "string"}, + }, "required": ["name"], }, }, @@ -103,10 +107,10 @@ class EvalResource: try: result = self.nsjail.python3( py_args=req.media["args"], - files=[EvalRequestFile.from_dict(file) for file in req.media.get("files", [])], + files=[FileAttachment.from_dict(file) for file in req.media.get("files", [])], ) - except FileParsingError as e: - raise falcon.HTTPBadRequest("Invalid file in request", str(e)) + except ParsingError as e: + raise falcon.HTTPBadRequest(description=f"Invalid file in request: {e}") except Exception: log.exception("An exception occurred while trying to process the request") raise falcon.HTTPInternalServerError diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py index 723bd30..69b6599 100644 --- a/snekbox/nsjail.py +++ b/snekbox/nsjail.py @@ -14,7 +14,7 @@ from snekbox.memfs import MemFS __all__ = ("NsJail",) from snekbox.process import EvalResult -from snekbox.snekio import AttachmentError, EvalRequestFile +from snekbox.snekio import AttachmentError, FileAttachment log = logging.getLogger(__name__) @@ -139,7 +139,7 @@ class NsJail: def python3( self, py_args: Iterable[str], - files: Iterable[EvalRequestFile] = (), + files: Iterable[FileAttachment] = (), *, nsjail_args: Iterable[str] = (), ) -> EvalResult: diff --git a/snekbox/snekio.py b/snekbox/snekio.py index 26acd04..020ca8c 100644 --- a/snekbox/snekio.py +++ b/snekbox/snekio.py @@ -1,12 +1,16 @@ +"""I/O Operations for sending / receiving files from the sandbox.""" from __future__ import annotations import zlib -from base64 import b64encode +from base64 import b64decode, b64encode from dataclasses import dataclass from pathlib import Path +from typing import Generic, TypeVar RequestType = dict[str, str | bool | list[str | dict[str, str]]] +T = TypeVar("T", str, bytes) + def sizeof_fmt(num: int, suffix: str = "B") -> str: """Return a human-readable file size.""" @@ -21,50 +25,47 @@ class AttachmentError(ValueError): """Raised when an attachment is invalid.""" -class FileParsingError(ValueError): - """Raised when a request file cannot be parsed.""" +class ParsingError(AttachmentError): + """Raised when an incoming file cannot be parsed.""" + + +class IllegalPathError(AttachmentError): + """Raised when an attachment has an illegal path.""" @dataclass -class EvalRequestFile: - """A file sent in an eval request.""" +class FileAttachment(Generic[T]): + """A file attachment.""" name: str - content: str + content: T @classmethod - def from_dict(cls, data: dict[str, str]) -> EvalRequestFile: - """Convert a dict to a str attachment.""" + def from_dict(cls, data: dict[str, str]) -> FileAttachment: + """Convert a dict to an attachment.""" name = data["name"] path = Path(name) parts = path.parts if path.is_absolute() or set(parts[0]) & {"\\", "/"}: - raise FileParsingError(f"File path '{name}' must be relative") + raise IllegalPathError(f"File path '{name}' must be relative") if any(set(part) == {"."} for part in parts): - raise FileParsingError(f"File path '{name}' may not use traversal ('..')") - - return cls(name, data.get("content", "")) + raise IllegalPathError(f"File path '{name}' may not use traversal ('..')") - def save_to(self, directory: Path) -> None: - """Save the attachment to a path directory.""" - file = Path(directory, self.name) - # Create directories if they don't exist - file.parent.mkdir(parents=True, exist_ok=True) - file.write_text(self.content) + match data.get("content-encoding"): + case "base64": + content = b64decode(data["content"]) + case None | "utf-8" | "": + content = data["content"].encode("utf-8") + case _: + raise ParsingError(f"Unknown content encoding '{data['content-encoding']}'") - -@dataclass -class FileAttachment: - """A file attachment.""" - - name: str - content: bytes + return cls(name, content) @classmethod - def from_path(cls, file: Path, max_size: int | None = None) -> FileAttachment: - """Create an attachment from a path.""" + def from_path(cls, file: Path, max_size: int | None = None) -> FileAttachment[bytes]: + """Create an attachment from a file path.""" size = file.stat().st_size if max_size is not None and size > max_size: raise AttachmentError( @@ -78,12 +79,30 @@ class FileAttachment: """Size of the attachment.""" return len(self.content) + def as_bytes(self) -> bytes: + """Return the attachment as bytes.""" + if isinstance(self.content, bytes): + return self.content + return self.content.encode("utf-8") + + def save_to(self, directory: Path) -> None: + """Save the attachment to a path directory.""" + file = Path(directory, self.name) + # Create directories if they don't exist + file.parent.mkdir(parents=True, exist_ok=True) + if isinstance(self.content, str): + file.write_text(self.content, encoding="utf-8") + else: + file.write_bytes(self.content) + def to_dict(self) -> dict[str, str]: """Convert the attachment to a dict.""" - cmp = zlib.compress(self.content) - content = b64encode(cmp).decode("ascii") + comp = zlib.compress(self.as_bytes()) + content = b64encode(comp).decode("ascii") + return { "name": self.name, "size": self.size, + "content-encoding": "base64+zlib", "content": content, } diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py index c61e2e5..ba6cac7 100644 --- a/tests/test_nsjail.py +++ b/tests/test_nsjail.py @@ -10,7 +10,7 @@ from pathlib import Path from textwrap import dedent from snekbox.nsjail import NsJail -from snekbox.snekio import EvalRequestFile +from snekbox.snekio import FileAttachment class NsJailTests(unittest.TestCase): @@ -25,7 +25,7 @@ class NsJailTests(unittest.TestCase): return self.nsjail.python3(["-c", code]) def eval_file(self, code: str, name: str = "test.py"): - file = EvalRequestFile(name, code) + file = FileAttachment(name, code) return self.nsjail.python3([name], [file]) def test_print_returns_0(self): @@ -58,8 +58,8 @@ class NsJailTests(unittest.TestCase): def test_multi_files(self): files = [ - EvalRequestFile("main.py", "import lib; print(lib.x)"), - EvalRequestFile("lib.py", "x = 'hello'"), + FileAttachment("main.py", "import lib; print(lib.x)"), + FileAttachment("lib.py", "x = 'hello'"), ] result = self.nsjail.python3(["main.py"], files) |