aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Ionite <[email protected]>2022-11-20 21:34:54 -0500
committerGravatar Ionite <[email protected]>2022-11-20 21:34:54 -0500
commit69ff6809331bcf6097a71c6f7b1087fe344c8797 (patch)
tree468718291881d63a6284181ed23a03741320399c
parentAdd unit test for multiple python files (diff)
Combined file handling to FileAttachment class
-rw-r--r--snekbox/api/resources/eval.py14
-rw-r--r--snekbox/nsjail.py4
-rw-r--r--snekbox/snekio.py77
-rw-r--r--tests/test_nsjail.py8
4 files changed, 63 insertions, 40 deletions
diff --git a/snekbox/api/resources/eval.py b/snekbox/api/resources/eval.py
index 9244d7e..af843d2 100644
--- a/snekbox/api/resources/eval.py
+++ b/snekbox/api/resources/eval.py
@@ -9,7 +9,7 @@ from snekbox.nsjail import NsJail
__all__ = ("EvalResource",)
-from snekbox.snekio import EvalRequestFile, FileParsingError
+from snekbox.snekio import FileAttachment, ParsingError
log = logging.getLogger(__name__)
@@ -32,7 +32,11 @@ class EvalResource:
"type": "array",
"items": {
"type": "object",
- "properties": {"name": {"type": "string"}, "content": {"type": "string"}},
+ "properties": {
+ "name": {"type": "string"},
+ "content-encoding": {"type": "string"},
+ "content": {"type": "string"},
+ },
"required": ["name"],
},
},
@@ -103,10 +107,10 @@ class EvalResource:
try:
result = self.nsjail.python3(
py_args=req.media["args"],
- files=[EvalRequestFile.from_dict(file) for file in req.media.get("files", [])],
+ files=[FileAttachment.from_dict(file) for file in req.media.get("files", [])],
)
- except FileParsingError as e:
- raise falcon.HTTPBadRequest("Invalid file in request", str(e))
+ except ParsingError as e:
+ raise falcon.HTTPBadRequest(description=f"Invalid file in request: {e}")
except Exception:
log.exception("An exception occurred while trying to process the request")
raise falcon.HTTPInternalServerError
diff --git a/snekbox/nsjail.py b/snekbox/nsjail.py
index 723bd30..69b6599 100644
--- a/snekbox/nsjail.py
+++ b/snekbox/nsjail.py
@@ -14,7 +14,7 @@ from snekbox.memfs import MemFS
__all__ = ("NsJail",)
from snekbox.process import EvalResult
-from snekbox.snekio import AttachmentError, EvalRequestFile
+from snekbox.snekio import AttachmentError, FileAttachment
log = logging.getLogger(__name__)
@@ -139,7 +139,7 @@ class NsJail:
def python3(
self,
py_args: Iterable[str],
- files: Iterable[EvalRequestFile] = (),
+ files: Iterable[FileAttachment] = (),
*,
nsjail_args: Iterable[str] = (),
) -> EvalResult:
diff --git a/snekbox/snekio.py b/snekbox/snekio.py
index 26acd04..020ca8c 100644
--- a/snekbox/snekio.py
+++ b/snekbox/snekio.py
@@ -1,12 +1,16 @@
+"""I/O Operations for sending / receiving files from the sandbox."""
from __future__ import annotations
import zlib
-from base64 import b64encode
+from base64 import b64decode, b64encode
from dataclasses import dataclass
from pathlib import Path
+from typing import Generic, TypeVar
RequestType = dict[str, str | bool | list[str | dict[str, str]]]
+T = TypeVar("T", str, bytes)
+
def sizeof_fmt(num: int, suffix: str = "B") -> str:
"""Return a human-readable file size."""
@@ -21,50 +25,47 @@ class AttachmentError(ValueError):
"""Raised when an attachment is invalid."""
-class FileParsingError(ValueError):
- """Raised when a request file cannot be parsed."""
+class ParsingError(AttachmentError):
+ """Raised when an incoming file cannot be parsed."""
+
+
+class IllegalPathError(AttachmentError):
+ """Raised when an attachment has an illegal path."""
@dataclass
-class EvalRequestFile:
- """A file sent in an eval request."""
+class FileAttachment(Generic[T]):
+ """A file attachment."""
name: str
- content: str
+ content: T
@classmethod
- def from_dict(cls, data: dict[str, str]) -> EvalRequestFile:
- """Convert a dict to a str attachment."""
+ def from_dict(cls, data: dict[str, str]) -> FileAttachment:
+ """Convert a dict to an attachment."""
name = data["name"]
path = Path(name)
parts = path.parts
if path.is_absolute() or set(parts[0]) & {"\\", "/"}:
- raise FileParsingError(f"File path '{name}' must be relative")
+ raise IllegalPathError(f"File path '{name}' must be relative")
if any(set(part) == {"."} for part in parts):
- raise FileParsingError(f"File path '{name}' may not use traversal ('..')")
-
- return cls(name, data.get("content", ""))
+ raise IllegalPathError(f"File path '{name}' may not use traversal ('..')")
- def save_to(self, directory: Path) -> None:
- """Save the attachment to a path directory."""
- file = Path(directory, self.name)
- # Create directories if they don't exist
- file.parent.mkdir(parents=True, exist_ok=True)
- file.write_text(self.content)
+ match data.get("content-encoding"):
+ case "base64":
+ content = b64decode(data["content"])
+ case None | "utf-8" | "":
+ content = data["content"].encode("utf-8")
+ case _:
+ raise ParsingError(f"Unknown content encoding '{data['content-encoding']}'")
-
-@dataclass
-class FileAttachment:
- """A file attachment."""
-
- name: str
- content: bytes
+ return cls(name, content)
@classmethod
- def from_path(cls, file: Path, max_size: int | None = None) -> FileAttachment:
- """Create an attachment from a path."""
+ def from_path(cls, file: Path, max_size: int | None = None) -> FileAttachment[bytes]:
+ """Create an attachment from a file path."""
size = file.stat().st_size
if max_size is not None and size > max_size:
raise AttachmentError(
@@ -78,12 +79,30 @@ class FileAttachment:
"""Size of the attachment."""
return len(self.content)
+ def as_bytes(self) -> bytes:
+ """Return the attachment as bytes."""
+ if isinstance(self.content, bytes):
+ return self.content
+ return self.content.encode("utf-8")
+
+ def save_to(self, directory: Path) -> None:
+ """Save the attachment to a path directory."""
+ file = Path(directory, self.name)
+ # Create directories if they don't exist
+ file.parent.mkdir(parents=True, exist_ok=True)
+ if isinstance(self.content, str):
+ file.write_text(self.content, encoding="utf-8")
+ else:
+ file.write_bytes(self.content)
+
def to_dict(self) -> dict[str, str]:
"""Convert the attachment to a dict."""
- cmp = zlib.compress(self.content)
- content = b64encode(cmp).decode("ascii")
+ comp = zlib.compress(self.as_bytes())
+ content = b64encode(comp).decode("ascii")
+
return {
"name": self.name,
"size": self.size,
+ "content-encoding": "base64+zlib",
"content": content,
}
diff --git a/tests/test_nsjail.py b/tests/test_nsjail.py
index c61e2e5..ba6cac7 100644
--- a/tests/test_nsjail.py
+++ b/tests/test_nsjail.py
@@ -10,7 +10,7 @@ from pathlib import Path
from textwrap import dedent
from snekbox.nsjail import NsJail
-from snekbox.snekio import EvalRequestFile
+from snekbox.snekio import FileAttachment
class NsJailTests(unittest.TestCase):
@@ -25,7 +25,7 @@ class NsJailTests(unittest.TestCase):
return self.nsjail.python3(["-c", code])
def eval_file(self, code: str, name: str = "test.py"):
- file = EvalRequestFile(name, code)
+ file = FileAttachment(name, code)
return self.nsjail.python3([name], [file])
def test_print_returns_0(self):
@@ -58,8 +58,8 @@ class NsJailTests(unittest.TestCase):
def test_multi_files(self):
files = [
- EvalRequestFile("main.py", "import lib; print(lib.x)"),
- EvalRequestFile("lib.py", "x = 'hello'"),
+ FileAttachment("main.py", "import lib; print(lib.x)"),
+ FileAttachment("lib.py", "x = 'hello'"),
]
result = self.nsjail.python3(["main.py"], files)