diff options
author | 2022-11-20 21:43:30 -0500 | |
---|---|---|
committer | 2022-11-20 21:43:30 -0500 | |
commit | 499336bd818cc4b9c1c4f5bb3c04a75f1730ada4 (patch) | |
tree | 54f9e9ae6c78166afde06bfafd0526f597cc15d9 | |
parent | Add support for displaying files from snekbox (diff) |
Implement full FileAttachment parsing
-rw-r--r-- | bot/exts/utils/snekbox.py | 194 |
1 files changed, 135 insertions, 59 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 1223b89ca..93941ed4c 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -4,14 +4,15 @@ import asyncio import contextlib import re import zlib -from base64 import b64decode -from dataclasses import dataclass +from base64 import b64decode, b64encode +from collections.abc import Iterable +from dataclasses import dataclass, field from functools import partial from io import BytesIO from operator import attrgetter from signal import Signals from textwrap import dedent -from typing import Literal, Optional, Tuple +from typing import Generic, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar from botcore.utils import interactions from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX @@ -26,6 +27,9 @@ from bot.utils import send_to_paste_service from bot.utils.lock import LockedResourceError, lock_arg from bot.utils.services import PasteTooLongError, PasteUploadError +if TYPE_CHECKING: + from bot.exts.filters.filtering import Filtering + log = get_logger(__name__) ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") @@ -84,29 +88,100 @@ SIGKILL = 9 REDO_EMOJI = '\U0001f501' # :repeat: REDO_TIMEOUT = 30 +# Note discord upload limit is 8 MB, or 50 MB for lvl 2 boosted servers +FILE_SIZE_LIMIT = 8 * 1024 * 1024 # 8 MiB + +T = TypeVar("T") + + +def sizeof_fmt(num: int, suffix: str = "B") -> str: + """Return a human-readable file size.""" + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024: + return f"{num:3.1f}{unit}{suffix}" + num /= 1024 + return f"{num:.1f}Yi{suffix}" + @dataclass -class FileAttachment: +class FileAttachment(Generic[T]): """File Attachment from Snekbox eval.""" name: str - mime: str - content: bytes + content: T + + def __repr__(self) -> str: + """Return the content as a string.""" + content = self.content if isinstance(self.content, str) else "(...)" + return f"FileAttachment(name={self.name}, content={content})" @classmethod - def from_dict(cls, data: dict) -> FileAttachment: + def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment[bytes]: """Create a FileAttachment from a dict response.""" - return cls( - data["name"], - data["mime"], - zlib.decompress(b64decode(data["content"])), - ) + size = data.get("size") + if (size and size > size_limit) or (len(data["content"]) > size_limit): + raise ValueError("File size exceeds limit") + + match data.get("content-encoding"): + case "base64+zlib": + content = zlib.decompress(b64decode(data["content"])) + case "base64": + content = b64decode(data["content"]) + case _: + content = data["content"] + + if len(content) > size_limit: + raise ValueError("File size exceeds limit") + + return cls(data["name"], content) + + def to_json(self) -> dict[str, str]: + """Convert the attachment to a json dict.""" + if isinstance(self.content, bytes): + content = b64encode(self.content).decode("ascii") + encoding = "base64" + else: + content = self.content + encoding = "" + + return { + "name": self.name, + "content-encoding": encoding, + "content": content, + } def to_file(self) -> File: """Convert to a discord.File.""" return File(BytesIO(self.content), filename=self.name) +@dataclass +class EvalJob: + """Represents a job to be evaluated by Snekbox.""" + + args: list[str] + files: list[FileAttachment] = field(default_factory=list) + + def __str__(self) -> str: + """Return the job as a string.""" + return f"EvalJob(args={self.args}, files={self.files})" + + @classmethod + def from_code(cls, code: str, files: Iterable[FileAttachment] = (), name: str = "main.py") -> EvalJob: + """Create an EvalJob from a code string.""" + return cls( + args=[name], + files=[FileAttachment(name, code), *files], + ) + + def to_json(self) -> dict[str, list[str | dict[str, str]]]: + """Convert the job to a dict.""" + return { + "args": self.args, + "files": [file.to_json() for file in self.files], + } + + class CodeblockConverter(Converter): """Attempts to extract code from a codeblock, if provided.""" @@ -151,10 +226,9 @@ class PythonVersionSwitcherButton(ui.Button): self, job_name: str, version_to_switch_to: Literal["3.10", "3.11"], - snekbox_cog: "Snekbox", + snekbox_cog: Snekbox, ctx: Context, - code: str, - args: Optional[list[str]] = None + job: EvalJob, ) -> None: self.version_to_switch_to = version_to_switch_to super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary) @@ -162,8 +236,7 @@ class PythonVersionSwitcherButton(ui.Button): self.snekbox_cog = snekbox_cog self.ctx = ctx self.job_name = job_name - self.code = code - self.args = args + self.job = job async def callback(self, interaction: Interaction) -> None: """ @@ -176,13 +249,11 @@ class PythonVersionSwitcherButton(ui.Button): await interaction.response.defer() with contextlib.suppress(NotFound): - # Suppress this delete to cover the case where a user re-runs code and very quickly clicks the button. + # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button. # The log arg on send_job will stop the actual job from running. await interaction.message.delete() - await self.snekbox_cog.run_job( - self.job_name, self.ctx, self.version_to_switch_to, self.code, args=self.args - ) + await self.snekbox_cog.run_job(self.job_name, self.ctx, self.version_to_switch_to, self.job) class Snekbox(Cog): @@ -197,8 +268,7 @@ class Snekbox(Cog): job_name: str, current_python_version: Literal["3.10", "3.11"], ctx: Context, - code: str, - args: Optional[list[str]] = None + job: EvalJob, ) -> interactions.ViewWithUserAndRoleCheck: """Return a view that allows the user to change what version of Python their code is run on.""" if current_python_version == "3.10": @@ -210,17 +280,15 @@ class Snekbox(Cog): allowed_users=(ctx.author.id,), allowed_roles=MODERATION_ROLES, ) - view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, code, args)) + view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, job)) view.add_item(interactions.DeleteMessageButton()) return view async def post_job( self, - code: str, + job: EvalJob, python_version: Literal["3.10", "3.11"], - *, - args: Optional[list[str]] = None ) -> dict: """Send a POST request to the Snekbox API to evaluate code and return the results.""" if python_version == "3.10": @@ -228,10 +296,7 @@ class Snekbox(Cog): else: url = URLs.snekbox_311_eval_api - data = {"input": code} - - if args is not None: - data["args"] = args + data = {"args": job.args, "files": [f.to_json() for f in job.files]} async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: return await resp.json() @@ -248,30 +313,34 @@ class Snekbox(Cog): return "unable to upload" @staticmethod - def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: + def prepare_timeit_input(codeblocks: list[str]) -> list[str]: """ Join the codeblocks into a single string, then return the code and the arguments in a tuple. If there are multiple codeblocks, insert the first one into the wrapped setup code. """ args = ["-m", "timeit"] - setup = "" - if len(codeblocks) > 1: - setup = codeblocks.pop(0) - + setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else "" code = "\n".join(codeblocks) - args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) - - return code, args + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code]) + return args @staticmethod def get_results_message( - results: dict, job_name: str, python_version: Literal["3.10", "3.11"] + results: dict, job_name: str, python_version: Literal["3.10", "3.11"] ) -> Tuple[str, str, list[FileAttachment]]: """Return a user-friendly message and error corresponding to the process's return code.""" stdout, returncode = results["stdout"], results["returncode"] - attachments = [FileAttachment.from_dict(d) for d in results["attachments"]] + + attachments: list[FileAttachment] = [] + failed_attachments: list[str] = [] + for attachment in results["attachments"]: + try: + attachments.append(FileAttachment.from_dict(attachment)) + except ValueError: + failed_attachments.append(attachment["name"]) + msg = f"Your {python_version} {job_name} job has completed with return code {returncode}" error = "" @@ -291,6 +360,14 @@ class Snekbox(Cog): except ValueError: pass + # Add error message for failed attachments + if failed_attachments: + failed_files = f"({', '.join(failed_attachments)})" + msg += ( + f".\n\n> Some attached files were not able to be uploaded {failed_files}." + f" Check that the file size is less than {sizeof_fmt(FILE_SIZE_LIMIT)}" + ) + return msg, error, attachments @staticmethod @@ -352,12 +429,10 @@ class Snekbox(Cog): @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) async def send_job( self, + job_name: str, ctx: Context, python_version: Literal["3.10", "3.11"], - code: str, - *, - args: Optional[list[str]] = None, - job_name: str + job: EvalJob, ) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. @@ -365,7 +440,7 @@ class Snekbox(Cog): Return the bot response. """ async with ctx.typing(): - results = await self.post_job(code, python_version, args=args) + results = await self.post_job(job, python_version) msg, error, attachments = self.get_results_message(results, job_name, python_version) if error: @@ -390,7 +465,7 @@ class Snekbox(Cog): else: self.bot.stats.incr("snekbox.python.success") - filter_cog = self.bot.get_cog("Filtering") + filter_cog: Filtering | None = self.bot.get_cog("Filtering") filter_triggered = False if filter_cog: filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) @@ -398,7 +473,7 @@ class Snekbox(Cog): response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) - view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args) + view = self.build_python_version_switcher_view(job_name, python_version, ctx, job) # Attach file if provided files = [atc.to_file() for atc in attachments] @@ -485,9 +560,7 @@ class Snekbox(Cog): job_name: str, ctx: Context, python_version: Literal["3.10", "3.11"], - code: str, - *, - args: Optional[list[str]] = None, + job: EvalJob, ) -> None: """Handles checks, stats and re-evaluation of a snekbox job.""" if Roles.helpers in (role.id for role in ctx.author.roles): @@ -502,11 +575,11 @@ class Snekbox(Cog): else: self.bot.stats.incr("snekbox_usages.channels.topical") - log.info(f"Received code from {ctx.author} for evaluation:\n{code}") + log.info(f"Received code from {ctx.author} for evaluation:\n{job}") while True: try: - response = await self.send_job(ctx, python_version, code, args=args, job_name=job_name) + response = await self.send_job(job_name, ctx, python_version, job) except LockedResourceError: await ctx.send( f"{ctx.author.mention} You've already got a job running - " @@ -514,7 +587,7 @@ class Snekbox(Cog): ) return - # Store the bot's response message id per invocation, to ensure the `wait_for`s in `continue_job` + # Store the bots response message id per invocation, to ensure the `wait_for`s in `continue_job` # don't trigger if the response has already been replaced by a new response. # This can happen when a button is pressed and then original code is edited and re-run. self.jobs[ctx.message.id] = response.id @@ -548,17 +621,19 @@ class Snekbox(Cog): clicking the reaction that subsequently appears. If multiple codeblocks are in a message, all of them will be joined and evaluated, - ignoring the text outside of them. + ignoring the text outside them. - By default your code is run on Python's 3.11 beta release, to assist with testing. If you + By default, your code is run on Python's 3.11 beta release, to assist with testing. If you run into issues related to this Python version, you can request the bot to use Python 3.10 by specifying the `python_version` arg and setting it to `3.10`. We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ + code: list[str] python_version = python_version or "3.11" - await self.run_job("eval", ctx, python_version, "\n".join(code)) + job = EvalJob.from_code("\n".join(code)) + await self.run_job("eval", ctx, python_version, job) @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>") @guild_only() @@ -593,10 +668,11 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ + code: list[str] python_version = python_version or "3.11" - code, args = self.prepare_timeit_input(code) + args = self.prepare_timeit_input(code) - await self.run_job("timeit", ctx, python_version, code=code, args=args) + await self.run_job("timeit", ctx, python_version, EvalJob(args)) def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: |