diff options
-rw-r--r-- | bot/exts/utils/snekbox.py | 296 | ||||
-rw-r--r-- | bot/exts/utils/snekio.py | 64 |
2 files changed, 185 insertions, 175 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 75c6f2d3a..d8a3088a6 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -3,25 +3,22 @@ from __future__ import annotations import asyncio import contextlib import re -from base64 import b64decode, b64encode -from collections.abc import Iterable from dataclasses import dataclass, field from functools import partial -from io import BytesIO from operator import attrgetter -from pathlib import Path from signal import Signals from textwrap import dedent -from typing import Generic, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar +from typing import Literal, Optional, TYPE_CHECKING, Tuple from botcore.utils import interactions from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX -from discord import AllowedMentions, File, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui +from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, MODERATION_ROLES, Roles, URLs from bot.decorators import redirect_output +from bot.exts.utils.snekio import FileAttachment, sizeof_fmt, FILE_SIZE_LIMIT from bot.log import get_logger from bot.utils import send_to_paste_service from bot.utils.lock import LockedResourceError, lock_arg @@ -88,89 +85,108 @@ SIGKILL = 9 REDO_EMOJI = '\U0001f501' # :repeat: REDO_TIMEOUT = 30 -# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers -FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB - -T = TypeVar("T") - - -def sizeof_fmt(num: int, suffix: str = "B") -> str: - """Return a human-readable file size.""" - for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): - if abs(num) < 1024: - return f"{num:3.1f}{unit}{suffix}" - num /= 1024 - return f"{num:.1f}Yi{suffix}" +PythonVersion = Literal["3.10", "3.11"] @dataclass -class FileAttachment(Generic[T]): - """File Attachment from Snekbox eval.""" - - path: str - content: T +class EvalJob: + """Job to be evaluated by snekbox.""" - def __repr__(self) -> str: - """Return the content as a string.""" - content = self.content if isinstance(self.content, str) else "(...)" - return f"FileAttachment(path={self.path}, content={content})" + args: list[str] + files: list[FileAttachment] = field(default_factory=list) + name: str = "eval" + version: PythonVersion = "3.11" @classmethod - def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment[bytes]: - """Create a FileAttachment from a dict response.""" - size = data.get("size") - if (size and size > size_limit) or (len(data["content"]) > size_limit): - raise ValueError("File size exceeds limit") + def from_code(cls, code: str, path: str = "main.py") -> EvalJob: + """Create an EvalJob from a code string.""" + return cls( + args=[path], + files=[FileAttachment(path, code.encode())], + ) - content = b64decode(data["content"]) + def as_version(self, version: PythonVersion) -> EvalJob: + """Return a copy of the job with a different Python version.""" + return EvalJob( + args=self.args, + files=self.files, + name=self.name, + version=version, + ) - if len(content) > size_limit: - raise ValueError("File size exceeds limit") + def to_dict(self) -> dict[str, list[str | dict[str, str]]]: + """Convert the job to a dict.""" + return { + "args": self.args, + "files": [file.to_dict() for file in self.files], + } - return cls(data["path"], content) - def to_json(self) -> dict[str, str]: - """Convert the attachment to a json dict.""" - content = self.content - if isinstance(content, str): - content = content.encode("utf-8") +@dataclass(frozen=True) +class EvalResult: + """The result of an eval job.""" - return { - "path": self.path, - "content": b64encode(content).decode("ascii"), - } + stdout: str + returncode: int | None + files: list[FileAttachment] = field(default_factory=list) + err_files: list[str] = field(default_factory=list) - def to_file(self) -> File: - """Convert to a discord.File.""" - name = Path(self.path).name - return File(BytesIO(self.content), filename=name) + @property + def status_emoji(self): + """Return an emoji corresponding to the status code or lack of output in result.""" + # If there are attachments, skip empty output warning + if not self.stdout.strip() and not self.files: # No output + return ":warning:" + elif self.returncode == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + def message(self, job: EvalJob) -> tuple[str, str]: + """Return a user-friendly message and error corresponding to the process's return code.""" + msg = f"Your {job.version} {job.name} job has completed with return code {self.returncode}" + error = "" -@dataclass -class EvalJob: - """Represents a job to be evaluated by Snekbox.""" + if self.returncode is None: + msg = f"Your {job.version} {job.name} job has failed" + error = self.stdout.strip() + elif self.returncode == 128 + SIGKILL: + msg = f"Your {job.version} {job.name} job timed out or ran out of memory" + elif self.returncode == 255: + msg = f"Your {job.version} {job.name} job has failed" + error = "A fatal NsJail error occurred" + else: + # Try to append signal's name if one exists + with contextlib.suppress(ValueError): + name = Signals(self.returncode - 128).name + msg = f"{msg} ({name})" - args: list[str] - files: list[FileAttachment] = field(default_factory=list) + # Add error message for failed attachments + if self.err_files: + failed_files = f"({', '.join(self.err_files)})" + msg += ( + f".\n\n> Some attached files were not able to be uploaded {failed_files}." + f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" + ) - def __repr__(self) -> str: - """Return the job as a string.""" - return f"EvalJob(args={self.args}, files={self.files})" + return msg, error @classmethod - def from_code(cls, code: str, files: Iterable[FileAttachment] = (), name: str = "main.py") -> EvalJob: - """Create an EvalJob from a code string.""" - return cls( - args=[name], - files=[FileAttachment(name, code), *files], + def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult: + """Create an EvalResult from a dict.""" + res = cls( + stdout=data["stdout"], + returncode=data["returncode"], ) - def to_json(self) -> dict[str, list[str | dict[str, str]]]: - """Convert the job to a dict.""" - return { - "args": self.args, - "files": [file.to_json() for file in self.files], - } + for file in data.get("files", []): + try: + res.files.append(FileAttachment.from_dict(file)) + except ValueError as e: + log.info(f"Failed to parse file from snekbox response: {e}") + res.err_files.append(file["path"]) + + return res class CodeblockConverter(Converter): @@ -214,19 +230,17 @@ class PythonVersionSwitcherButton(ui.Button): """A button that allows users to re-run their eval command in a different Python version.""" def __init__( - self, - job_name: str, - version_to_switch_to: Literal["3.10", "3.11"], - snekbox_cog: Snekbox, - ctx: Context, - job: EvalJob, + self, + version_to_switch_to: PythonVersion, + snekbox_cog: Snekbox, + ctx: Context, + job: EvalJob, ) -> None: self.version_to_switch_to = version_to_switch_to super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary) self.snekbox_cog = snekbox_cog self.ctx = ctx - self.job_name = job_name self.job = job async def callback(self, interaction: Interaction) -> None: @@ -244,7 +258,7 @@ class PythonVersionSwitcherButton(ui.Button): # The log arg on send_job will stop the actual job from running. await interaction.message.delete() - await self.snekbox_cog.run_job(self.job_name, self.ctx, self.version_to_switch_to, self.job) + await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to)) class Snekbox(Cog): @@ -255,11 +269,10 @@ class Snekbox(Cog): self.jobs = {} def build_python_version_switcher_view( - self, - job_name: str, - current_python_version: Literal["3.10", "3.11"], - ctx: Context, - job: EvalJob, + self, + current_python_version: PythonVersion, + ctx: Context, + job: EvalJob, ) -> interactions.ViewWithUserAndRoleCheck: """Return a view that allows the user to change what version of Python their code is run on.""" if current_python_version == "3.10": @@ -271,28 +284,25 @@ class Snekbox(Cog): allowed_users=(ctx.author.id,), allowed_roles=MODERATION_ROLES, ) - view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, job)) + view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job)) view.add_item(interactions.DeleteMessageButton()) return view - async def post_job( - self, - job: EvalJob, - python_version: Literal["3.10", "3.11"], - ) -> dict: + async def post_job(self, job: EvalJob) -> EvalResult: """Send a POST request to the Snekbox API to evaluate code and return the results.""" - if python_version == "3.10": + if job.version == "3.10": url = URLs.snekbox_eval_api else: url = URLs.snekbox_311_eval_api - data = {"args": job.args, "files": [f.to_json() for f in job.files]} + data = job.to_dict() async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: return await resp.json() - async def upload_output(self, output: str) -> Optional[str]: + @staticmethod + async def upload_output(output: str) -> Optional[str]: """Upload the job's output to a paste service and return a URL to it if successful.""" log.trace("Uploading full output to paste service...") @@ -317,61 +327,6 @@ class Snekbox(Cog): args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code]) return args - @staticmethod - def get_results_message( - results: dict, job_name: str, python_version: Literal["3.10", "3.11"] - ) -> Tuple[str, str, list[FileAttachment]]: - """Return a user-friendly message and error corresponding to the process's return code.""" - stdout, returncode = results["stdout"], results["returncode"] - - attachments: list[FileAttachment] = [] - failed_attachments: list[str] = [] - for attachment in results.get("files", []): - try: - attachments.append(FileAttachment.from_dict(attachment)) - except ValueError: - failed_attachments.append(attachment["path"]) - - msg = f"Your {python_version} {job_name} job has completed with return code {returncode}" - error = "" - - if returncode is None: - msg = f"Your {python_version} {job_name} job has failed" - error = stdout.strip() - elif returncode == 128 + SIGKILL: - msg = f"Your {python_version} {job_name} job timed out or ran out of memory" - elif returncode == 255: - msg = f"Your {python_version} {job_name} job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - try: - name = Signals(returncode - 128).name - msg = f"{msg} ({name})" - except ValueError: - pass - - # Add error message for failed attachments - if failed_attachments: - failed_files = f"({', '.join(failed_attachments)})" - msg += ( - f".\n\n> Some attached files were not able to be uploaded {failed_files}." - f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" - ) - - return msg, error, attachments - - @staticmethod - def get_status_emoji(results: dict) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - # If there are attachments, skip empty output warning - if not results["stdout"].strip() and not results.get("files"): # No output - return ":warning:" - elif results["returncode"] == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: """ Format the output and return a tuple of the formatted output and a URL to the full output. @@ -419,40 +374,32 @@ class Snekbox(Cog): return output, paste_link @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) - async def send_job( - self, - job_name: str, - ctx: Context, - python_version: Literal["3.10", "3.11"], - job: EvalJob, - ) -> Message: + async def send_job(self, ctx: Context, job: EvalJob) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. Return the bot response. """ async with ctx.typing(): - results = await self.post_job(job, python_version) - msg, error, attachments = self.get_results_message(results, job_name, python_version) + result = await self.post_job(job) + msg, error = result.message(job) if error: output, paste_link = error, None else: log.trace("Formatting output...") - output, paste_link = await self.format_output(results["stdout"]) - - icon = self.get_status_emoji(results) + output, paste_link = await self.format_output(result.stdout) - if attachments and output in ("[No output]", ""): - msg = f"{ctx.author.mention} {icon} {msg}.\n" + if result.files and output in ("[No output]", ""): + msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n" else: - msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" + msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n\n```\n{output}\n```" if paste_link: msg = f"{msg}\nFull output: {paste_link}" # Collect stats of job fails + successes - if icon == ":x:": + if result.returncode != 0: self.bot.stats.incr("snekbox.python.fail") else: self.bot.stats.incr("snekbox.python.success") @@ -465,14 +412,14 @@ class Snekbox(Cog): response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) - view = self.build_python_version_switcher_view(job_name, python_version, ctx, job) + view = self.build_python_version_switcher_view(job.version, ctx, job) - # Attach file if provided - files = [atc.to_file() for atc in attachments] + # Attach files if provided + files = [f.to_file() for f in result.files] response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files) view.message = response - log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") + log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}") return response async def continue_job( @@ -549,9 +496,7 @@ class Snekbox(Cog): async def run_job( self, - job_name: str, ctx: Context, - python_version: Literal["3.10", "3.11"], job: EvalJob, ) -> None: """Handles checks, stats and re-evaluation of a snekbox job.""" @@ -571,7 +516,7 @@ class Snekbox(Cog): while True: try: - response = await self.send_job(job_name, ctx, python_version, job) + response = await self.send_job(ctx, job) except LockedResourceError: await ctx.send( f"{ctx.author.mention} You've already got a job running - " @@ -584,7 +529,7 @@ class Snekbox(Cog): # This can happen when a button is pressed and then original code is edited and re-run. self.jobs[ctx.message.id] = response.id - job = await self.continue_job(ctx, response, job_name) + job = await self.continue_job(ctx, response, job.name) if not job: break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}") @@ -601,7 +546,7 @@ class Snekbox(Cog): async def eval_command( self, ctx: Context, - python_version: Optional[Literal["3.10", "3.11"]], + python_version: PythonVersion | None, *, code: CodeblockConverter ) -> None: @@ -624,8 +569,8 @@ class Snekbox(Cog): """ code: list[str] python_version = python_version or "3.11" - job = EvalJob.from_code("\n".join(code)) - await self.run_job("eval", ctx, python_version, job) + job = EvalJob.from_code("\n".join(code)).as_version(python_version) + await self.run_job(ctx, job) @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>") @guild_only() @@ -639,7 +584,7 @@ class Snekbox(Cog): async def timeit_command( self, ctx: Context, - python_version: Optional[Literal["3.10", "3.11"]], + python_version: PythonVersion | None, *, code: CodeblockConverter ) -> None: @@ -663,8 +608,9 @@ class Snekbox(Cog): code: list[str] python_version = python_version or "3.11" args = self.prepare_timeit_input(code) + job = EvalJob(args, version=python_version, name="timeit") - await self.run_job("timeit", ctx, python_version, EvalJob(args)) + await self.run_job(ctx, job) def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: diff --git a/bot/exts/utils/snekio.py b/bot/exts/utils/snekio.py new file mode 100644 index 000000000..7c5fba648 --- /dev/null +++ b/bot/exts/utils/snekio.py @@ -0,0 +1,64 @@ +"""I/O File protocols for snekbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path + +from discord import File + +# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers +FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB + + +def sizeof_fmt(num: int, suffix: str = "B") -> str: + """Return a human-readable file size.""" + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024: + return f"{num:3.1f}{unit}{suffix}" + num /= 1024 + return f"{num:.1f}Yi{suffix}" + + +@dataclass +class FileAttachment: + """File Attachment from Snekbox eval.""" + + path: str + content: bytes + + def __repr__(self) -> str: + """Return the content as a string.""" + content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content + return f"FileAttachment(path={self.path!r}, content={content})" + + @classmethod + def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: + """Create a FileAttachment from a dict response.""" + size = data.get("size") + if (size and size > size_limit) or (len(data["content"]) > size_limit): + raise ValueError("File size exceeds limit") + + content = b64decode(data["content"]) + + if len(content) > size_limit: + raise ValueError("File size exceeds limit") + + return cls(data["path"], content) + + def to_dict(self) -> dict[str, str]: + """Convert the attachment to a json dict.""" + content = self.content + if isinstance(content, str): + content = content.encode("utf-8") + + return { + "path": self.path, + "content": b64encode(content).decode("ascii"), + } + + def to_file(self) -> File: + """Convert to a discord.File.""" + name = Path(self.path).name + return File(BytesIO(self.content), filename=name) |