aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ionite34 <[email protected]>2022-11-30 07:48:12 +0800
committerGravatar ionite34 <[email protected]>2022-11-30 07:48:12 +0800
commit51e7e17e06aedc9f05347a5196faf137dc5a00ae (patch)
tree197f3abfd3654c1c171f178242f345504deacb32
parentUpdate unit test (diff)
Refactors for EvalResult and EvalJob dataclasses
-rw-r--r--bot/exts/utils/snekbox.py296
-rw-r--r--bot/exts/utils/snekio.py64
2 files changed, 185 insertions, 175 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 75c6f2d3a..d8a3088a6 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -3,25 +3,22 @@ from __future__ import annotations
import asyncio
import contextlib
import re
-from base64 import b64decode, b64encode
-from collections.abc import Iterable
from dataclasses import dataclass, field
from functools import partial
-from io import BytesIO
from operator import attrgetter
-from pathlib import Path
from signal import Signals
from textwrap import dedent
-from typing import Generic, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar
+from typing import Literal, Optional, TYPE_CHECKING, Tuple
from botcore.utils import interactions
from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX
-from discord import AllowedMentions, File, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui
+from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui
from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only
from bot.bot import Bot
from bot.constants import Categories, Channels, MODERATION_ROLES, Roles, URLs
from bot.decorators import redirect_output
+from bot.exts.utils.snekio import FileAttachment, sizeof_fmt, FILE_SIZE_LIMIT
from bot.log import get_logger
from bot.utils import send_to_paste_service
from bot.utils.lock import LockedResourceError, lock_arg
@@ -88,89 +85,108 @@ SIGKILL = 9
REDO_EMOJI = '\U0001f501' # :repeat:
REDO_TIMEOUT = 30
-# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers
-FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB
-
-T = TypeVar("T")
-
-
-def sizeof_fmt(num: int, suffix: str = "B") -> str:
- """Return a human-readable file size."""
- for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
- if abs(num) < 1024:
- return f"{num:3.1f}{unit}{suffix}"
- num /= 1024
- return f"{num:.1f}Yi{suffix}"
+PythonVersion = Literal["3.10", "3.11"]
@dataclass
-class FileAttachment(Generic[T]):
- """File Attachment from Snekbox eval."""
-
- path: str
- content: T
+class EvalJob:
+ """Job to be evaluated by snekbox."""
- def __repr__(self) -> str:
- """Return the content as a string."""
- content = self.content if isinstance(self.content, str) else "(...)"
- return f"FileAttachment(path={self.path}, content={content})"
+ args: list[str]
+ files: list[FileAttachment] = field(default_factory=list)
+ name: str = "eval"
+ version: PythonVersion = "3.11"
@classmethod
- def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment[bytes]:
- """Create a FileAttachment from a dict response."""
- size = data.get("size")
- if (size and size > size_limit) or (len(data["content"]) > size_limit):
- raise ValueError("File size exceeds limit")
+ def from_code(cls, code: str, path: str = "main.py") -> EvalJob:
+ """Create an EvalJob from a code string."""
+ return cls(
+ args=[path],
+ files=[FileAttachment(path, code.encode())],
+ )
- content = b64decode(data["content"])
+ def as_version(self, version: PythonVersion) -> EvalJob:
+ """Return a copy of the job with a different Python version."""
+ return EvalJob(
+ args=self.args,
+ files=self.files,
+ name=self.name,
+ version=version,
+ )
- if len(content) > size_limit:
- raise ValueError("File size exceeds limit")
+ def to_dict(self) -> dict[str, list[str | dict[str, str]]]:
+ """Convert the job to a dict."""
+ return {
+ "args": self.args,
+ "files": [file.to_dict() for file in self.files],
+ }
- return cls(data["path"], content)
- def to_json(self) -> dict[str, str]:
- """Convert the attachment to a json dict."""
- content = self.content
- if isinstance(content, str):
- content = content.encode("utf-8")
+@dataclass(frozen=True)
+class EvalResult:
+ """The result of an eval job."""
- return {
- "path": self.path,
- "content": b64encode(content).decode("ascii"),
- }
+ stdout: str
+ returncode: int | None
+ files: list[FileAttachment] = field(default_factory=list)
+ err_files: list[str] = field(default_factory=list)
- def to_file(self) -> File:
- """Convert to a discord.File."""
- name = Path(self.path).name
- return File(BytesIO(self.content), filename=name)
+ @property
+ def status_emoji(self):
+ """Return an emoji corresponding to the status code or lack of output in result."""
+ # If there are attachments, skip empty output warning
+ if not self.stdout.strip() and not self.files: # No output
+ return ":warning:"
+ elif self.returncode == 0: # No error
+ return ":white_check_mark:"
+ else: # Exception
+ return ":x:"
+ def message(self, job: EvalJob) -> tuple[str, str]:
+ """Return a user-friendly message and error corresponding to the process's return code."""
+ msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}"
+ error = ""
-@dataclass
-class EvalJob:
- """Represents a job to be evaluated by Snekbox."""
+ if self.returncode is None:
+ msg = f"Your {job.version} {job.name} job has failed"
+ error = self.stdout.strip()
+ elif self.returncode == 128 + SIGKILL:
+ msg = f"Your {job.version} {job.name} job timed out or ran out of memory"
+ elif self.returncode == 255:
+ msg = f"Your {job.version} {job.name} job has failed"
+ error = "A fatal NsJail error occurred"
+ else:
+ # Try to append signal's name if one exists
+ with contextlib.suppress(ValueError):
+ name = Signals(self.returncode - 128).name
+ msg = f"{msg} ({name})"
- args: list[str]
- files: list[FileAttachment] = field(default_factory=list)
+ # Add error message for failed attachments
+ if self.err_files:
+ failed_files = f"({', '.join(self.err_files)})"
+ msg += (
+ f".\n\n> Some attached files were not able to be uploaded {failed_files}."
+ f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}"
+ )
- def __repr__(self) -> str:
- """Return the job as a string."""
- return f"EvalJob(args={self.args}, files={self.files})"
+ return msg, error
@classmethod
- def from_code(cls, code: str, files: Iterable[FileAttachment] = (), name: str = "main.py") -> EvalJob:
- """Create an EvalJob from a code string."""
- return cls(
- args=[name],
- files=[FileAttachment(name, code), *files],
+ def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult:
+ """Create an EvalResult from a dict."""
+ res = cls(
+ stdout=data["stdout"],
+ returncode=data["returncode"],
)
- def to_json(self) -> dict[str, list[str | dict[str, str]]]:
- """Convert the job to a dict."""
- return {
- "args": self.args,
- "files": [file.to_json() for file in self.files],
- }
+ for file in data.get("files", []):
+ try:
+ res.files.append(FileAttachment.from_dict(file))
+ except ValueError as e:
+ log.info(f"Failed to parse file from snekbox response: {e}")
+ res.err_files.append(file["path"])
+
+ return res
class CodeblockConverter(Converter):
@@ -214,19 +230,17 @@ class PythonVersionSwitcherButton(ui.Button):
"""A button that allows users to re-run their eval command in a different Python version."""
def __init__(
- self,
- job_name: str,
- version_to_switch_to: Literal["3.10", "3.11"],
- snekbox_cog: Snekbox,
- ctx: Context,
- job: EvalJob,
+ self,
+ version_to_switch_to: PythonVersion,
+ snekbox_cog: Snekbox,
+ ctx: Context,
+ job: EvalJob,
) -> None:
self.version_to_switch_to = version_to_switch_to
super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary)
self.snekbox_cog = snekbox_cog
self.ctx = ctx
- self.job_name = job_name
self.job = job
async def callback(self, interaction: Interaction) -> None:
@@ -244,7 +258,7 @@ class PythonVersionSwitcherButton(ui.Button):
# The log arg on send_job will stop the actual job from running.
await interaction.message.delete()
- await self.snekbox_cog.run_job(self.job_name, self.ctx, self.version_to_switch_to, self.job)
+ await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to))
class Snekbox(Cog):
@@ -255,11 +269,10 @@ class Snekbox(Cog):
self.jobs = {}
def build_python_version_switcher_view(
- self,
- job_name: str,
- current_python_version: Literal["3.10", "3.11"],
- ctx: Context,
- job: EvalJob,
+ self,
+ current_python_version: PythonVersion,
+ ctx: Context,
+ job: EvalJob,
) -> interactions.ViewWithUserAndRoleCheck:
"""Return a view that allows the user to change what version of Python their code is run on."""
if current_python_version == "3.10":
@@ -271,28 +284,25 @@ class Snekbox(Cog):
allowed_users=(ctx.author.id,),
allowed_roles=MODERATION_ROLES,
)
- view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, job))
+ view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job))
view.add_item(interactions.DeleteMessageButton())
return view
- async def post_job(
- self,
- job: EvalJob,
- python_version: Literal["3.10", "3.11"],
- ) -> dict:
+ async def post_job(self, job: EvalJob) -> EvalResult:
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
- if python_version == "3.10":
+ if job.version == "3.10":
url = URLs.snekbox_eval_api
else:
url = URLs.snekbox_311_eval_api
- data = {"args": job.args, "files": [f.to_json() for f in job.files]}
+ data = job.to_dict()
async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
return await resp.json()
- async def upload_output(self, output: str) -> Optional[str]:
+ @staticmethod
+ async def upload_output(output: str) -> Optional[str]:
"""Upload the job's output to a paste service and return a URL to it if successful."""
log.trace("Uploading full output to paste service...")
@@ -317,61 +327,6 @@ class Snekbox(Cog):
args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code])
return args
- @staticmethod
- def get_results_message(
- results: dict, job_name: str, python_version: Literal["3.10", "3.11"]
- ) -> Tuple[str, str, list[FileAttachment]]:
- """Return a user-friendly message and error corresponding to the process's return code."""
- stdout, returncode = results["stdout"], results["returncode"]
-
- attachments: list[FileAttachment] = []
- failed_attachments: list[str] = []
- for attachment in results.get("files", []):
- try:
- attachments.append(FileAttachment.from_dict(attachment))
- except ValueError:
- failed_attachments.append(attachment["path"])
-
- msg = f"Your {python_version} {job_name} job has completed with return code {returncode}"
- error = ""
-
- if returncode is None:
- msg = f"Your {python_version} {job_name} job has failed"
- error = stdout.strip()
- elif returncode == 128 + SIGKILL:
- msg = f"Your {python_version} {job_name} job timed out or ran out of memory"
- elif returncode == 255:
- msg = f"Your {python_version} {job_name} job has failed"
- error = "A fatal NsJail error occurred"
- else:
- # Try to append signal's name if one exists
- try:
- name = Signals(returncode - 128).name
- msg = f"{msg} ({name})"
- except ValueError:
- pass
-
- # Add error message for failed attachments
- if failed_attachments:
- failed_files = f"({', '.join(failed_attachments)})"
- msg += (
- f".\n\n> Some attached files were not able to be uploaded {failed_files}."
- f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}"
- )
-
- return msg, error, attachments
-
- @staticmethod
- def get_status_emoji(results: dict) -> str:
- """Return an emoji corresponding to the status code or lack of output in result."""
- # If there are attachments, skip empty output warning
- if not results["stdout"].strip() and not results.get("files"): # No output
- return ":warning:"
- elif results["returncode"] == 0: # No error
- return ":white_check_mark:"
- else: # Exception
- return ":x:"
-
async def format_output(self, output: str) -> Tuple[str, Optional[str]]:
"""
Format the output and return a tuple of the formatted output and a URL to the full output.
@@ -419,40 +374,32 @@ class Snekbox(Cog):
return output, paste_link
@lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True)
- async def send_job(
- self,
- job_name: str,
- ctx: Context,
- python_version: Literal["3.10", "3.11"],
- job: EvalJob,
- ) -> Message:
+ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
Return the bot response.
"""
async with ctx.typing():
- results = await self.post_job(job, python_version)
- msg, error, attachments = self.get_results_message(results, job_name, python_version)
+ result = await self.post_job(job)
+ msg, error = result.message(job)
if error:
output, paste_link = error, None
else:
log.trace("Formatting output...")
- output, paste_link = await self.format_output(results["stdout"])
-
- icon = self.get_status_emoji(results)
+ output, paste_link = await self.format_output(result.stdout)
- if attachments and output in ("[No output]", ""):
- msg = f"{ctx.author.mention} {icon} {msg}.\n"
+ if result.files and output in ("[No output]", ""):
+ msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n"
else:
- msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```"
+ msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n\n```\n{output}\n```"
if paste_link:
msg = f"{msg}\nFull output: {paste_link}"
# Collect stats of job fails + successes
- if icon == ":x:":
+ if result.returncode != 0:
self.bot.stats.incr("snekbox.python.fail")
else:
self.bot.stats.incr("snekbox.python.success")
@@ -465,14 +412,14 @@ class Snekbox(Cog):
response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")
else:
allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
- view = self.build_python_version_switcher_view(job_name, python_version, ctx, job)
+ view = self.build_python_version_switcher_view(job.version, ctx, job)
- # Attach file if provided
- files = [atc.to_file() for atc in attachments]
+ # Attach files if provided
+ files = [f.to_file() for f in result.files]
response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files)
view.message = response
- log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")
+ log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}")
return response
async def continue_job(
@@ -549,9 +496,7 @@ class Snekbox(Cog):
async def run_job(
self,
- job_name: str,
ctx: Context,
- python_version: Literal["3.10", "3.11"],
job: EvalJob,
) -> None:
"""Handles checks, stats and re-evaluation of a snekbox job."""
@@ -571,7 +516,7 @@ class Snekbox(Cog):
while True:
try:
- response = await self.send_job(job_name, ctx, python_version, job)
+ response = await self.send_job(ctx, job)
except LockedResourceError:
await ctx.send(
f"{ctx.author.mention} You've already got a job running - "
@@ -584,7 +529,7 @@ class Snekbox(Cog):
# This can happen when a button is pressed and then original code is edited and re-run.
self.jobs[ctx.message.id] = response.id
- job = await self.continue_job(ctx, response, job_name)
+ job = await self.continue_job(ctx, response, job.name)
if not job:
break
log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}")
@@ -601,7 +546,7 @@ class Snekbox(Cog):
async def eval_command(
self,
ctx: Context,
- python_version: Optional[Literal["3.10", "3.11"]],
+ python_version: PythonVersion | None,
*,
code: CodeblockConverter
) -> None:
@@ -624,8 +569,8 @@ class Snekbox(Cog):
"""
code: list[str]
python_version = python_version or "3.11"
- job = EvalJob.from_code("\n".join(code))
- await self.run_job("eval", ctx, python_version, job)
+ job = EvalJob.from_code("\n".join(code)).as_version(python_version)
+ await self.run_job(ctx, job)
@command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>")
@guild_only()
@@ -639,7 +584,7 @@ class Snekbox(Cog):
async def timeit_command(
self,
ctx: Context,
- python_version: Optional[Literal["3.10", "3.11"]],
+ python_version: PythonVersion | None,
*,
code: CodeblockConverter
) -> None:
@@ -663,8 +608,9 @@ class Snekbox(Cog):
code: list[str]
python_version = python_version or "3.11"
args = self.prepare_timeit_input(code)
+ job = EvalJob(args, version=python_version, name="timeit")
- await self.run_job("timeit", ctx, python_version, EvalJob(args))
+ await self.run_job(ctx, job)
def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
diff --git a/bot/exts/utils/snekio.py b/bot/exts/utils/snekio.py
new file mode 100644
index 000000000..7c5fba648
--- /dev/null
+++ b/bot/exts/utils/snekio.py
@@ -0,0 +1,64 @@
+"""I/O File protocols for snekbox."""
+from __future__ import annotations
+
+from base64 import b64decode, b64encode
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+
+from discord import File
+
+# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers
+FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB
+
+
+def sizeof_fmt(num: int, suffix: str = "B") -> str:
+ """Return a human-readable file size."""
+ for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
+ if abs(num) < 1024:
+ return f"{num:3.1f}{unit}{suffix}"
+ num /= 1024
+ return f"{num:.1f}Yi{suffix}"
+
+
+@dataclass
+class FileAttachment:
+ """File Attachment from Snekbox eval."""
+
+ path: str
+ content: bytes
+
+ def __repr__(self) -> str:
+ """Return the content as a string."""
+ content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content
+ return f"FileAttachment(path={self.path!r}, content={content})"
+
+ @classmethod
+ def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment:
+ """Create a FileAttachment from a dict response."""
+ size = data.get("size")
+ if (size and size > size_limit) or (len(data["content"]) > size_limit):
+ raise ValueError("File size exceeds limit")
+
+ content = b64decode(data["content"])
+
+ if len(content) > size_limit:
+ raise ValueError("File size exceeds limit")
+
+ return cls(data["path"], content)
+
+ def to_dict(self) -> dict[str, str]:
+ """Convert the attachment to a json dict."""
+ content = self.content
+ if isinstance(content, str):
+ content = content.encode("utf-8")
+
+ return {
+ "path": self.path,
+ "content": b64encode(content).decode("ascii"),
+ }
+
+ def to_file(self) -> File:
+ """Convert to a discord.File."""
+ name = Path(self.path).name
+ return File(BytesIO(self.content), filename=name)