diff options
| -rw-r--r-- | bot/exts/utils/snekbox.py | 55 | 
1 files changed, 47 insertions, 8 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 5e217a288..1223b89ca 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -1,7 +1,13 @@ +from __future__ import annotations +  import asyncio  import contextlib  import re +import zlib +from base64 import b64decode +from dataclasses import dataclass  from functools import partial +from io import BytesIO  from operator import attrgetter  from signal import Signals  from textwrap import dedent @@ -9,7 +15,7 @@ from typing import Literal, Optional, Tuple  from botcore.utils import interactions  from botcore.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX -from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui +from discord import AllowedMentions, File, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui  from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only  from bot.bot import Bot @@ -79,6 +85,28 @@ REDO_EMOJI = '\U0001f501'  # :repeat:  REDO_TIMEOUT = 30 +@dataclass +class FileAttachment: +    """File Attachment from Snekbox eval.""" + +    name: str +    mime: str +    content: bytes + +    @classmethod +    def from_dict(cls, data: dict) -> FileAttachment: +        """Create a FileAttachment from a dict response.""" +        return cls( +            data["name"], +            data["mime"], +            zlib.decompress(b64decode(data["content"])), +        ) + +    def to_file(self) -> File: +        """Convert to a discord.File.""" +        return File(BytesIO(self.content), filename=self.name) + +  class CodeblockConverter(Converter):      """Attempts to extract code from a codeblock, if provided.""" @@ -171,7 +199,7 @@ class Snekbox(Cog):          ctx: Context,          code: str,          args: Optional[list[str]] = None -    ) -> None: +    ) -> interactions.ViewWithUserAndRoleCheck:          """Return a view that allows the user to change what version of Python their code is run on."""          if current_python_version == "3.10":              alt_python_version = "3.11" @@ -238,9 +266,12 @@ class Snekbox(Cog):          return code, args      @staticmethod -    def get_results_message(results: dict, job_name: str, python_version: Literal["3.10", "3.11"]) -> Tuple[str, str]: +    def get_results_message( +            results: dict, job_name: str, python_version: Literal["3.10", "3.11"] +    ) -> Tuple[str, str, list[FileAttachment]]:          """Return a user-friendly message and error corresponding to the process's return code."""          stdout, returncode = results["stdout"], results["returncode"] +        attachments = [FileAttachment.from_dict(d) for d in results["attachments"]]          msg = f"Your {python_version} {job_name} job has completed with return code {returncode}"          error = "" @@ -260,12 +291,12 @@ class Snekbox(Cog):              except ValueError:                  pass -        return msg, error +        return msg, error, attachments      @staticmethod      def get_status_emoji(results: dict) -> str:          """Return an emoji corresponding to the status code or lack of output in result.""" -        if not results["stdout"].strip():  # No output +        if not results["stdout"].strip() and not results["attachments"]:  # No output              return ":warning:"          elif results["returncode"] == 0:  # No error              return ":white_check_mark:" @@ -335,7 +366,7 @@ class Snekbox(Cog):          """          async with ctx.typing():              results = await self.post_job(code, python_version, args=args) -            msg, error = self.get_results_message(results, job_name, python_version) +            msg, error, attachments = self.get_results_message(results, job_name, python_version)              if error:                  output, paste_link = error, None @@ -344,7 +375,12 @@ class Snekbox(Cog):                  output, paste_link = await self.format_output(results["stdout"])              icon = self.get_status_emoji(results) -            msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" + +            if attachments and output in ("[No output]", ""): +                msg = f"{ctx.author.mention} {icon} {msg}.\n" +            else: +                msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" +              if paste_link:                  msg = f"{msg}\nFull output: {paste_link}" @@ -363,7 +399,10 @@ class Snekbox(Cog):              else:                  allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])                  view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args) -                response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view) + +                # Attach file if provided +                files = [atc.to_file() for atc in attachments] +                response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files)                  view.message = response              log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")  |