diff options
| -rw-r--r-- | bot/constants.py | 1 | ||||
| -rw-r--r-- | bot/exts/utils/snekbox/__init__.py | 12 | ||||
| -rw-r--r-- | bot/exts/utils/snekbox/_cog.py (renamed from bot/exts/utils/snekbox.py) | 373 | ||||
| -rw-r--r-- | bot/exts/utils/snekbox/_eval.py | 183 | ||||
| -rw-r--r-- | bot/exts/utils/snekbox/_io.py | 102 | ||||
| -rw-r--r-- | tests/bot/exts/utils/snekbox/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/utils/snekbox/test_io.py | 34 | ||||
| -rw-r--r-- | tests/bot/exts/utils/snekbox/test_snekbox.py (renamed from tests/bot/exts/utils/test_snekbox.py) | 266 | 
8 files changed, 727 insertions, 244 deletions
diff --git a/bot/constants.py b/bot/constants.py index a4d5761be..3aacd0a16 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -584,6 +584,7 @@ class _Emojis(EnvConfig):      defcon_update = "<:defconsettingsupdated:470326274082996224>"  # noqa: E704      failmail = "<:failmail:633660039931887616>" +    failed_file = "<:failed_file:1073298441968562226>"      incident_actioned = "<:incident_actioned:714221559279255583>"      incident_investigating = "<:incident_investigating:714224190928191551>" diff --git a/bot/exts/utils/snekbox/__init__.py b/bot/exts/utils/snekbox/__init__.py new file mode 100644 index 000000000..cd1d3b059 --- /dev/null +++ b/bot/exts/utils/snekbox/__init__.py @@ -0,0 +1,12 @@ +from bot.bot import Bot +from bot.exts.utils.snekbox._cog import CodeblockConverter, Snekbox +from bot.exts.utils.snekbox._eval import EvalJob, EvalResult + +__all__ = ("CodeblockConverter", "Snekbox", "EvalJob", "EvalResult") + + +async def setup(bot: Bot) -> None: +    """Load the Snekbox cog.""" +    # Defer import to reduce side effects from importing the codeblock package. +    from bot.exts.utils.snekbox._cog import Snekbox +    await bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox/_cog.py index ddcbe01fa..b48fcf592 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -1,11 +1,12 @@ +from __future__ import annotations +  import asyncio  import contextlib  import re  from functools import partial  from operator import attrgetter -from signal import Signals  from textwrap import dedent -from typing import Literal, Optional, Tuple +from typing import Literal, NamedTuple, Optional, TYPE_CHECKING  from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui  from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only @@ -13,14 +14,21 @@ from pydis_core.utils import interactions  from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX  from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, Roles, URLs +from bot.constants import Channels, Emojis, Filter, MODERATION_ROLES, Roles, URLs  from bot.decorators import redirect_output +from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME +from bot.exts.filters.antimalware import TXT_LIKE_FILES  from bot.exts.help_channels._channel import is_help_forum_post +from bot.exts.utils.snekbox._eval import EvalJob, EvalResult +from bot.exts.utils.snekbox._io import FileAttachment  from bot.log import get_logger  from bot.utils import send_to_paste_service  from bot.utils.lock import LockedResourceError, lock_arg  from bot.utils.services import PasteTooLongError, PasteUploadError +if TYPE_CHECKING: +    from bot.exts.filters.filtering import Filtering +  log = get_logger(__name__)  ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") @@ -68,17 +76,23 @@ if not hasattr(sys, "_setup_finished"):  """  MAX_PASTE_LENGTH = 10_000 +# Max to display in a codeblock before sending to a paste service +# This also applies to text files +MAX_OUTPUT_BLOCK_LINES = 10 +MAX_OUTPUT_BLOCK_CHARS = 1000  # The Snekbox commands' whitelists and blacklists.  NO_SNEKBOX_CHANNELS = (Channels.python_general,)  NO_SNEKBOX_CATEGORIES = ()  SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) -SIGKILL = 9 -  REDO_EMOJI = '\U0001f501'  # :repeat:  REDO_TIMEOUT = 30 +PythonVersion = Literal["3.10", "3.11"] + +FilteredFiles = NamedTuple("FilteredFiles", [("allowed", list[FileAttachment]), ("blocked", list[FileAttachment])]) +  class CodeblockConverter(Converter):      """Attempts to extract code from a codeblock, if provided.""" @@ -122,21 +136,17 @@ class PythonVersionSwitcherButton(ui.Button):      def __init__(          self, -        job_name: str, -        version_to_switch_to: Literal["3.10", "3.11"], -        snekbox_cog: "Snekbox", +        version_to_switch_to: PythonVersion, +        snekbox_cog: Snekbox,          ctx: Context, -        code: str, -        args: Optional[list[str]] = None +        job: EvalJob,      ) -> None:          self.version_to_switch_to = version_to_switch_to          super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary)          self.snekbox_cog = snekbox_cog          self.ctx = ctx -        self.job_name = job_name -        self.code = code -        self.args = args +        self.job = job      async def callback(self, interaction: Interaction) -> None:          """ @@ -149,13 +159,11 @@ class PythonVersionSwitcherButton(ui.Button):          await interaction.response.defer()          with contextlib.suppress(NotFound): -            # Suppress this delete to cover the case where a user re-runs code and very quickly clicks the button. +            # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button.              # The log arg on send_job will stop the actual job from running.              await interaction.message.delete() -        await self.snekbox_cog.run_job( -            self.job_name, self.ctx, self.version_to_switch_to, self.code, args=self.args -        ) +        await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to))  class Snekbox(Cog): @@ -167,13 +175,12 @@ class Snekbox(Cog):      def build_python_version_switcher_view(          self, -        job_name: str, -        current_python_version: Literal["3.10", "3.11"], +        current_python_version: PythonVersion,          ctx: Context, -        code: str, -        args: Optional[list[str]] = None -    ) -> None: +        job: EvalJob, +    ) -> interactions.ViewWithUserAndRoleCheck:          """Return a view that allows the user to change what version of Python their code is run on.""" +        alt_python_version: PythonVersion          if current_python_version == "3.10":              alt_python_version = "3.11"          else: @@ -183,33 +190,25 @@ class Snekbox(Cog):              allowed_users=(ctx.author.id,),              allowed_roles=MODERATION_ROLES,          ) -        view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, code, args)) +        view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job))          view.add_item(interactions.DeleteMessageButton())          return view -    async def post_job( -        self, -        code: str, -        python_version: Literal["3.10", "3.11"], -        *, -        args: Optional[list[str]] = None -    ) -> dict: +    async def post_job(self, job: EvalJob) -> EvalResult:          """Send a POST request to the Snekbox API to evaluate code and return the results.""" -        if python_version == "3.10": +        if job.version == "3.10":              url = URLs.snekbox_eval_api          else:              url = URLs.snekbox_311_eval_api -        data = {"input": code} - -        if args is not None: -            data["args"] = args +        data = job.to_dict()          async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: -            return await resp.json() +            return EvalResult.from_dict(await resp.json()) -    async def upload_output(self, output: str) -> Optional[str]: +    @staticmethod +    async def upload_output(output: str) -> Optional[str]:          """Upload the job's output to a paste service and return a URL to it if successful."""          log.trace("Uploading full output to paste service...") @@ -221,59 +220,27 @@ class Snekbox(Cog):              return "unable to upload"      @staticmethod -    def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: +    def prepare_timeit_input(codeblocks: list[str]) -> list[str]:          """ -        Join the codeblocks into a single string, then return the code and the arguments in a tuple. +        Join the codeblocks into a single string, then return the arguments in a list.          If there are multiple codeblocks, insert the first one into the wrapped setup code.          """          args = ["-m", "timeit"] -        setup = "" -        if len(codeblocks) > 1: -            setup = codeblocks.pop(0) - +        setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else ""          code = "\n".join(codeblocks) -        args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) - -        return code, args - -    @staticmethod -    def get_results_message(results: dict, job_name: str, python_version: Literal["3.10", "3.11"]) -> Tuple[str, str]: -        """Return a user-friendly message and error corresponding to the process's return code.""" -        stdout, returncode = results["stdout"], results["returncode"] -        msg = f"Your {python_version} {job_name} job has completed with return code {returncode}" -        error = "" - -        if returncode is None: -            msg = f"Your {python_version} {job_name} job has failed" -            error = stdout.strip() -        elif returncode == 128 + SIGKILL: -            msg = f"Your {python_version} {job_name} job timed out or ran out of memory" -        elif returncode == 255: -            msg = f"Your {python_version} {job_name} job has failed" -            error = "A fatal NsJail error occurred" -        else: -            # Try to append signal's name if one exists -            try: -                name = Signals(returncode - 128).name -                msg = f"{msg} ({name})" -            except ValueError: -                pass - -        return msg, error +        args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code]) +        return args -    @staticmethod -    def get_status_emoji(results: dict) -> str: -        """Return an emoji corresponding to the status code or lack of output in result.""" -        if not results["stdout"].strip():  # No output -            return ":warning:" -        elif results["returncode"] == 0:  # No error -            return ":white_check_mark:" -        else:  # Exception -            return ":x:" - -    async def format_output(self, output: str) -> Tuple[str, Optional[str]]: +    async def format_output( +        self, +        output: str, +        max_lines: int = MAX_OUTPUT_BLOCK_LINES, +        max_chars: int = MAX_OUTPUT_BLOCK_CHARS, +        line_nums: bool = True, +        output_default: str = "[No output]", +    ) -> tuple[str, str | None]:          """          Format the output and return a tuple of the formatted output and a URL to the full output. @@ -295,96 +262,182 @@ class Snekbox(Cog):              return "Code block escape attempt detected; will not output result", paste_link          truncated = False -        lines = output.count("\n") +        lines = output.splitlines() -        if lines > 0: -            output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] -            output = output[:11]  # Limiting to only 11 lines -            output = "\n".join(output) +        if len(lines) > 1: +            if line_nums: +                lines = [f"{i:03d} | {line}" for i, line in enumerate(lines, 1)] +            lines = lines[:max_lines+1]  # Limiting to max+1 lines +            output = "\n".join(lines) -        if lines > 10: +        if len(lines) > max_lines:              truncated = True -            if len(output) >= 1000: -                output = f"{output[:1000]}\n... (truncated - too long, too many lines)" +            if len(output) >= max_chars: +                output = f"{output[:max_chars]}\n... (truncated - too long, too many lines)"              else:                  output = f"{output}\n... (truncated - too many lines)" -        elif len(output) >= 1000: +        elif len(output) >= max_chars:              truncated = True -            output = f"{output[:1000]}\n... (truncated - too long)" +            output = f"{output[:max_chars]}\n... (truncated - too long)"          if truncated:              paste_link = await self.upload_output(original_output) -        output = output or "[No output]" +        if output_default and not output: +            output = output_default          return output, paste_link +    def get_extensions_whitelist(self) -> set[str]: +        """Return a set of whitelisted file extensions.""" +        return set(self.bot.filter_list_cache['FILE_FORMAT.True'].keys()) | TXT_LIKE_FILES + +    def _filter_files(self, ctx: Context, files: list[FileAttachment]) -> FilteredFiles: +        """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" +        # Check if user is staff, if is, return +        # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance +        if hasattr(ctx.author, "roles") and any(role.id in Filter.role_whitelist for role in ctx.author.roles): +            return FilteredFiles(files, []) +        # Ignore code jam channels +        if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME: +            return FilteredFiles(files, []) + +        # Get whitelisted extensions +        whitelist = self.get_extensions_whitelist() + +        # Filter files into allowed and blocked +        blocked = [] +        allowed = [] +        for file in files: +            if file.suffix in whitelist: +                allowed.append(file) +            else: +                blocked.append(file) + +        if blocked: +            blocked_str = ", ".join(f.suffix for f in blocked) +            log.info( +                f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}", +                extra={"attachment_list": [f.path for f in files]} +            ) + +        return FilteredFiles(allowed, blocked) +      @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) -    async def send_job( -        self, -        ctx: Context, -        python_version: Literal["3.10", "3.11"], -        code: str, -        *, -        args: Optional[list[str]] = None, -        job_name: str -    ) -> Message: +    async def send_job(self, ctx: Context, job: EvalJob) -> Message:          """          Evaluate code, format it, and send the output to the corresponding channel.          Return the bot response.          """          async with ctx.typing(): -            results = await self.post_job(code, python_version, args=args) -            msg, error = self.get_results_message(results, job_name, python_version) +            result = await self.post_job(job) +            msg = result.get_message(job) +            error = result.error_message              if error:                  output, paste_link = error, None              else:                  log.trace("Formatting output...") -                output, paste_link = await self.format_output(results["stdout"]) +                output, paste_link = await self.format_output(result.stdout) -            warning_message = "" +            msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n"              # This is done to make sure the last line of output contains the error              # and the error is not manually printed by the author with a syntax error. -            if results["stdout"].rstrip().endswith("EOFError: EOF when reading a line") and results["returncode"] == 1: -                warning_message += ":warning: Note: `input` is not supported by the bot :warning:\n\n" +            if result.stdout.rstrip().endswith("EOFError: EOF when reading a line") and result.returncode == 1: +                msg += ":warning: Note: `input` is not supported by the bot :warning:\n\n" + +            # Skip output if it's empty and there are file uploads +            if result.stdout or not result.has_files: +                msg += f"\n```\n{output}\n```" -            icon = self.get_status_emoji(results) -            msg = f"{ctx.author.mention} {icon} {msg}.\n\n{warning_message}```\n{output}\n```"              if paste_link: -                msg = f"{msg}\nFull output: {paste_link}" +                msg += f"\nFull output: {paste_link}" + +            # Additional files error message after output +            if files_error := result.files_error_message: +                msg += f"\n{files_error}"              # Collect stats of job fails + successes -            if icon == ":x:": +            if result.returncode != 0:                  self.bot.stats.incr("snekbox.python.fail")              else:                  self.bot.stats.incr("snekbox.python.success") -            filter_cog = self.bot.get_cog("Filtering") -            filter_triggered = False -            if filter_cog: -                filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) -            if filter_triggered: -                response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") -            else: -                allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) -                view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args) -                response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view) -                view.message = response - -            log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") +            # Filter file extensions +            allowed, blocked = self._filter_files(ctx, result.files) +            # Also scan failed files for blocked extensions +            failed_files = [FileAttachment(name, b"") for name in result.failed_files] +            blocked.extend(self._filter_files(ctx, failed_files).blocked) +            # Add notice if any files were blocked +            if blocked: +                blocked_sorted = sorted(set(f.suffix for f in blocked)) +                # Only no extension +                if len(blocked_sorted) == 1 and blocked_sorted[0] == "": +                    blocked_msg = "Files with no extension can't be uploaded." +                # Both +                elif "" in blocked_sorted: +                    blocked_str = ", ".join(ext for ext in blocked_sorted if ext) +                    blocked_msg = ( +                        f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" +                    ) +                else: +                    blocked_str = ", ".join(blocked_sorted) +                    blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + +                msg += f"\n{Emojis.failed_file} {blocked_msg}" + +            # Split text files +            text_files = [f for f in allowed if f.suffix in TXT_LIKE_FILES] +            # Inline until budget, then upload to paste service +            # Budget is shared with stdout, so subtract what we've already used +            budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) +            budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output) +            for file in text_files: +                file_text = file.content.decode("utf-8", errors="replace") or "[Empty]" +                # Override to always allow 1 line and <= 50 chars, since this is less than a link +                if len(file_text) <= 50 and not file_text.count("\n"): +                    msg += f"\n`{file.name}`\n```\n{file_text}\n```" +                # otherwise, use budget +                else: +                    format_text, link_text = await self.format_output( +                        file_text, +                        budget_lines, +                        budget_chars, +                        line_nums=False, +                        output_default="[Empty]" +                    ) +                    # With any link, use it (don't use budget) +                    if link_text: +                        msg += f"\n`{file.name}`\n{link_text}" +                    else: +                        msg += f"\n`{file.name}`\n```\n{format_text}\n```" +                        budget_lines -= format_text.count("\n") + 1 +                        budget_chars -= len(file_text) + +            filter_cog: Filtering | None = self.bot.get_cog("Filtering") +            if filter_cog and (await filter_cog.filter_snekbox_output(msg, ctx.message)): +                return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + +            # Upload remaining non-text files +            files = [f.to_file() for f in allowed if f not in text_files] +            allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) +            view = self.build_python_version_switcher_view(job.version, ctx, job) +            response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files) +            view.message = response + +            log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}")          return response      async def continue_job(          self, ctx: Context, response: Message, job_name: str -    ) -> tuple[Optional[str], Optional[list[str]]]: +    ) -> EvalJob | None:          """          Check if the job's session should continue. -        If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command. -        Otherwise return (None, None) if the job's session should be terminated. +        If the code is to be re-evaluated, return the new EvalJob. +        Otherwise, return None if the job's session should be terminated.          """          _predicate_message_edit = partial(predicate_message_edit, ctx)          _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) @@ -406,7 +459,7 @@ class Snekbox(Cog):                  # Ensure the response that's about to be edited is still the most recent.                  # This could have already been updated via a button press to switch to an alt Python version.                  if self.jobs[ctx.message.id] != response.id: -                    return None, None +                    return None                  code = await self.get_code(new_message, ctx.command)                  with contextlib.suppress(HTTPException): @@ -414,21 +467,21 @@ class Snekbox(Cog):                      await response.delete()                  if code is None: -                    return None, None +                    return None              except asyncio.TimeoutError:                  with contextlib.suppress(HTTPException):                      await ctx.message.clear_reaction(REDO_EMOJI) -                return None, None +                return None              codeblocks = await CodeblockConverter.convert(ctx, code)              if job_name == "timeit": -                return self.prepare_timeit_input(codeblocks) +                return EvalJob(self.prepare_timeit_input(codeblocks))              else: -                return "\n".join(codeblocks), None +                return EvalJob.from_code("\n".join(codeblocks)) -        return None, None +        return None      async def get_code(self, message: Message, command: Command) -> Optional[str]:          """ @@ -452,12 +505,8 @@ class Snekbox(Cog):      async def run_job(          self, -        job_name: str,          ctx: Context, -        python_version: Literal["3.10", "3.11"], -        code: str, -        *, -        args: Optional[list[str]] = None, +        job: EvalJob,      ) -> None:          """Handles checks, stats and re-evaluation of a snekbox job."""          if Roles.helpers in (role.id for role in ctx.author.roles): @@ -472,11 +521,11 @@ class Snekbox(Cog):          else:              self.bot.stats.incr("snekbox_usages.channels.topical") -        log.info(f"Received code from {ctx.author} for evaluation:\n{code}") +        log.info(f"Received code from {ctx.author} for evaluation:\n{job}")          while True:              try: -                response = await self.send_job(ctx, python_version, code, args=args, job_name=job_name) +                response = await self.send_job(ctx, job)              except LockedResourceError:                  await ctx.send(                      f"{ctx.author.mention} You've already got a job running - " @@ -489,10 +538,10 @@ class Snekbox(Cog):              # This can happen when a button is pressed and then original code is edited and re-run.              self.jobs[ctx.message.id] = response.id -            code, args = await self.continue_job(ctx, response, job_name) -            if not code: +            job = await self.continue_job(ctx, response, job.name) +            if not job:                  break -            log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") +            log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}")      @command(name="eval", aliases=("e",), usage="[python_version] <code, ...>")      @guild_only() @@ -506,29 +555,32 @@ class Snekbox(Cog):      async def eval_command(          self,          ctx: Context, -        python_version: Optional[Literal["3.10", "3.11"]], +        python_version: PythonVersion | None,          *,          code: CodeblockConverter      ) -> None:          """          Run Python code and get the results. -        This command supports multiple lines of code, including code wrapped inside a formatted code -        block. Code can be re-evaluated by editing the original message within 10 seconds and +        This command supports multiple lines of code, including formatted code blocks. +        Code can be re-evaluated by editing the original message within 10 seconds and          clicking the reaction that subsequently appears. +        The starting working directory `/home`, is a writeable temporary file system. +        Files created, excluding names with leading underscores, will be uploaded in the response. +          If multiple codeblocks are in a message, all of them will be joined and evaluated, -        ignoring the text outside of them. +        ignoring the text outside them. -        By default your code is run on Python's 3.11 beta release, to assist with testing. If you -        run into issues related to this Python version, you can request the bot to use Python -        3.10 by specifying the `python_version` arg and setting it to `3.10`. +        By default, your code is run on Python 3.11. A `python_version` arg of `3.10` can also be specified.          We've done our best to make this sandboxed, but do let us know if you manage to find an          issue with it!          """ +        code: list[str]          python_version = python_version or "3.11" -        await self.run_job("eval", ctx, python_version, "\n".join(code)) +        job = EvalJob.from_code("\n".join(code)).as_version(python_version) +        await self.run_job(ctx, job)      @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>")      @guild_only() @@ -542,7 +594,7 @@ class Snekbox(Cog):      async def timeit_command(          self,          ctx: Context, -        python_version: Optional[Literal["3.10", "3.11"]], +        python_version: PythonVersion | None,          *,          code: CodeblockConverter      ) -> None: @@ -556,17 +608,17 @@ class Snekbox(Cog):          If multiple formatted codeblocks are provided, the first one will be the setup code, which will          not be timed. The remaining codeblocks will be joined together and timed. -        By default your code is run on Python's 3.11 beta release, to assist with testing. If you -        run into issues related to this Python version, you can request the bot to use Python -        3.10 by specifying the `python_version` arg and setting it to `3.10`. +        By default, your code is run on Python 3.11. A `python_version` arg of `3.10` can also be specified.          We've done our best to make this sandboxed, but do let us know if you manage to find an          issue with it!          """ +        code: list[str]          python_version = python_version or "3.11" -        code, args = self.prepare_timeit_input(code) +        args = self.prepare_timeit_input(code) +        job = EvalJob(args, version=python_version, name="timeit") -        await self.run_job("timeit", ctx, python_version, code=code, args=args) +        await self.run_job(ctx, job)  def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: @@ -577,8 +629,3 @@ def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) ->  def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool:      """Return True if the reaction REDO_EMOJI was added by the context message author on this message."""      return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI - - -async def setup(bot: Bot) -> None: -    """Load the Snekbox cog.""" -    await bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py new file mode 100644 index 000000000..2f61b5924 --- /dev/null +++ b/bot/exts/utils/snekbox/_eval.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from signal import Signals +from typing import TYPE_CHECKING + +from discord.utils import escape_markdown, escape_mentions + +from bot.constants import Emojis +from bot.exts.utils.snekbox._io import FILE_COUNT_LIMIT, FILE_SIZE_LIMIT, FileAttachment, sizeof_fmt +from bot.log import get_logger + +if TYPE_CHECKING: +    from bot.exts.utils.snekbox._cog import PythonVersion + +log = get_logger(__name__) + +SIGKILL = 9 + + +@dataclass(frozen=True) +class EvalJob: +    """Job to be evaluated by snekbox.""" + +    args: list[str] +    files: list[FileAttachment] = field(default_factory=list) +    name: str = "eval" +    version: PythonVersion = "3.11" + +    @classmethod +    def from_code(cls, code: str, path: str = "main.py") -> EvalJob: +        """Create an EvalJob from a code string.""" +        return cls( +            args=[path], +            files=[FileAttachment(path, code.encode())], +        ) + +    def as_version(self, version: PythonVersion) -> EvalJob: +        """Return a copy of the job with a different Python version.""" +        return EvalJob( +            args=self.args, +            files=self.files, +            name=self.name, +            version=version, +        ) + +    def to_dict(self) -> dict[str, list[str | dict[str, str]]]: +        """Convert the job to a dict.""" +        return { +            "args": self.args, +            "files": [file.to_dict() for file in self.files], +        } + + +@dataclass(frozen=True) +class EvalResult: +    """The result of an eval job.""" + +    stdout: str +    returncode: int | None +    files: list[FileAttachment] = field(default_factory=list) +    failed_files: list[str] = field(default_factory=list) + +    @property +    def has_output(self) -> bool: +        """True if the result has any output (stdout, files, or failed files).""" +        return bool(self.stdout.strip() or self.files or self.failed_files) + +    @property +    def has_files(self) -> bool: +        """True if the result has any files or failed files.""" +        return bool(self.files or self.failed_files) + +    @property +    def status_emoji(self) -> str: +        """Return an emoji corresponding to the status code or lack of output in result.""" +        if not self.has_output: +            return ":warning:" +        elif self.returncode == 0:  # No error +            return ":white_check_mark:" +        else:  # Exception +            return ":x:" + +    @property +    def error_message(self) -> str: +        """Return an error message corresponding to the process's return code.""" +        error = "" +        if self.returncode is None: +            error = self.stdout.strip() +        elif self.returncode == 255: +            error = "A fatal NsJail error occurred" +        return error + +    @property +    def files_error_message(self) -> str: +        """Return an error message corresponding to the failed files.""" +        if not self.failed_files: +            return "" + +        failed_files = f"({self.get_failed_files_str()})" + +        n_failed = len(self.failed_files) +        s_upload = "uploads" if n_failed > 1 else "upload" + +        msg = f"{Emojis.failed_file} {n_failed} file {s_upload} {failed_files} failed" + +        # Exceeded file count limit +        if (n_failed + len(self.files)) > FILE_COUNT_LIMIT: +            s_it = "they" if n_failed > 1 else "it" +            msg += f" as {s_it} exceeded the {FILE_COUNT_LIMIT} file limit." +        # Exceeded file size limit +        else: +            s_each_file = "each file's" if n_failed > 1 else "its file" +            msg += f" because {s_each_file} size exceeds {sizeof_fmt(FILE_SIZE_LIMIT)}." + +        return msg + +    def get_failed_files_str(self, char_max: int = 85) -> str: +        """ +        Return a string containing the names of failed files, truncated char_max. + +        Will truncate on whole file names if less than 3 characters remaining. +        """ +        names = [] +        for file in self.failed_files: +            # Only attempt to truncate name if more than 3 chars remaining +            if char_max < 3: +                names.append("...") +                break + +            if len(file) > char_max: +                names.append(file[:char_max] + "...") +                break +            char_max -= len(file) +            names.append(file) + +        text = ", ".join(names) +        # Since the file names are provided by user +        text = escape_markdown(text) +        text = escape_mentions(text) +        return text + +    def get_message(self, job: EvalJob) -> str: +        """Return a user-friendly message corresponding to the process's return code.""" +        msg = f"Your {job.version} {job.name} job" + +        if self.returncode is None: +            msg += " has failed" +        elif self.returncode == 128 + SIGKILL: +            msg += " timed out or ran out of memory" +        elif self.returncode == 255: +            msg += " has failed" +        else: +            msg += f" has completed with return code {self.returncode}" +            # Try to append signal's name if one exists +            with contextlib.suppress(ValueError): +                name = Signals(self.returncode - 128).name +                msg += f" ({name})" + +        return msg + +    @classmethod +    def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult: +        """Create an EvalResult from a dict.""" +        res = cls( +            stdout=data["stdout"], +            returncode=data["returncode"], +        ) + +        files = iter(data["files"]) +        for i, file in enumerate(files): +            # Limit to FILE_COUNT_LIMIT files +            if i >= FILE_COUNT_LIMIT: +                res.failed_files.extend(file["path"] for file in files) +                break +            try: +                res.files.append(FileAttachment.from_dict(file)) +            except ValueError as e: +                log.info(f"Failed to parse file from snekbox response: {e}") +                res.failed_files.append(file["path"]) + +        return res diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py new file mode 100644 index 000000000..9be396335 --- /dev/null +++ b/bot/exts/utils/snekbox/_io.py @@ -0,0 +1,102 @@ +"""I/O File protocols for snekbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from io import BytesIO +from pathlib import PurePosixPath + +import regex +from discord import File + +# Note discord bot upload limit is 8 MiB per file, +# or 50 MiB for lvl 2 boosted servers +FILE_SIZE_LIMIT = 8 * 1024 * 1024 + +# Discord currently has a 10-file limit per message +FILE_COUNT_LIMIT = 10 + + +# ANSI escape sequences +RE_ANSI = regex.compile(r"\\u.*\[(.*?)m") +# Characters with a leading backslash +RE_BACKSLASH = regex.compile(r"\\.") +# Discord disallowed file name characters +RE_DISCORD_FILE_NAME_DISALLOWED = regex.compile(r"[^a-zA-Z0-9._-]+") + + +def sizeof_fmt(num: int | float, suffix: str = "B") -> str: +    """Return a human-readable file size.""" +    num = float(num) +    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): +        if abs(num) < 1024: +            num_str = f"{int(num)}" if num.is_integer() else f"{num:3.1f}" +            return f"{num_str} {unit}{suffix}" +        num /= 1024 +    num_str = f"{int(num)}" if num.is_integer() else f"{num:3.1f}" +    return f"{num_str} Yi{suffix}" + + +def normalize_discord_file_name(name: str) -> str: +    """Return a normalized valid discord file name.""" +    # Discord file names only allow A-Z, a-z, 0-9, underscores, dashes, and dots +    # https://discord.com/developers/docs/reference#uploading-files +    # Server will remove any other characters, but we'll get a 400 error for \ escaped chars +    name = RE_ANSI.sub("_", name) +    name = RE_BACKSLASH.sub("_", name) +    # Replace any disallowed character with an underscore +    name = RE_DISCORD_FILE_NAME_DISALLOWED.sub("_", name) +    return name + + +@dataclass(frozen=True) +class FileAttachment: +    """File Attachment from Snekbox eval.""" + +    path: str +    content: bytes + +    def __repr__(self) -> str: +        """Return the content as a string.""" +        content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content +        return f"FileAttachment(path={self.path!r}, content={content})" + +    @property +    def suffix(self) -> str: +        """Return the file suffix.""" +        return PurePosixPath(self.path).suffix + +    @property +    def name(self) -> str: +        """Return the file name.""" +        return PurePosixPath(self.path).name + +    @classmethod +    def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: +        """Create a FileAttachment from a dict response.""" +        size = data.get("size") +        if (size and size > size_limit) or (len(data["content"]) > size_limit): +            raise ValueError("File size exceeds limit") + +        content = b64decode(data["content"]) + +        if len(content) > size_limit: +            raise ValueError("File size exceeds limit") + +        return cls(data["path"], content) + +    def to_dict(self) -> dict[str, str]: +        """Convert the attachment to a json dict.""" +        content = self.content +        if isinstance(content, str): +            content = content.encode("utf-8") + +        return { +            "path": self.path, +            "content": b64encode(content).decode("ascii"), +        } + +    def to_file(self) -> File: +        """Convert to a discord.File.""" +        name = normalize_discord_file_name(self.name) +        return File(BytesIO(self.content), filename=name) diff --git a/tests/bot/exts/utils/snekbox/__init__.py b/tests/bot/exts/utils/snekbox/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/utils/snekbox/__init__.py diff --git a/tests/bot/exts/utils/snekbox/test_io.py b/tests/bot/exts/utils/snekbox/test_io.py new file mode 100644 index 000000000..bcf1162b8 --- /dev/null +++ b/tests/bot/exts/utils/snekbox/test_io.py @@ -0,0 +1,34 @@ +from unittest import TestCase + +# noinspection PyProtectedMember +from bot.exts.utils.snekbox import _io + + +class SnekboxIOTests(TestCase): +    # noinspection SpellCheckingInspection +    def test_normalize_file_name(self): +        """Invalid file names should be normalized.""" +        cases = [ +            # ANSI escape sequences -> underscore +            (r"\u001b[31mText", "_Text"), +            # (Multiple consecutive should be collapsed to one underscore) +            (r"a\u001b[35m\u001b[37mb", "a_b"), +            # Backslash escaped chars -> underscore +            (r"\n", "_"), +            (r"\r", "_"), +            (r"A\0\tB", "A__B"), +            # Any other disallowed chars -> underscore +            (r"\\.txt", "_.txt"), +            (r"A!@#$%^&*B, C()[]{}+=D.txt", "A_B_C_D.txt"),  # noqa: P103 +            (" ", "_"), +            # Normal file names should be unchanged +            ("legal_file-name.txt", "legal_file-name.txt"), +            ("_-.", "_-."), +        ] +        for name, expected in cases: +            with self.subTest(name=name, expected=expected): +                # Test function directly +                self.assertEqual(_io.normalize_discord_file_name(name), expected) +                # Test FileAttachment.to_file() +                obj = _io.FileAttachment(name, b"") +                self.assertEqual(obj.to_file().filename, expected) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index b1f32c210..9dcf7fd8c 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -1,5 +1,6 @@  import asyncio  import unittest +from base64 import b64encode  from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch  from discord import AllowedMentions @@ -8,7 +9,8 @@ from discord.ext import commands  from bot import constants  from bot.errors import LockedResourceError  from bot.exts.utils import snekbox -from bot.exts.utils.snekbox import Snekbox +from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox +from bot.exts.utils.snekbox._io import FileAttachment  from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser @@ -17,34 +19,55 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Add mocked bot and cog to the instance."""          self.bot = MockBot()          self.cog = Snekbox(bot=self.bot) +        self.job = EvalJob.from_code("import random") + +    @staticmethod +    def code_args(code: str) -> tuple[EvalJob]: +        """Converts code to a tuple of arguments expected.""" +        return EvalJob.from_code(code),      async def test_post_job(self):          """Post the eval code to the URLs.snekbox_eval_api endpoint."""          resp = MagicMock() -        resp.json = AsyncMock(return_value="return") +        resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []})          context_manager = MagicMock()          context_manager.__aenter__.return_value = resp          self.bot.http_session.post.return_value = context_manager -        self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") +        job = EvalJob.from_code("import random").as_version("3.10") +        self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137)) + +        expected = { +            "args": ["main.py"], +            "files": [ +                { +                    "path": "main.py", +                    "content": b64encode("import random".encode()).decode() +                } +            ] +        }          self.bot.http_session.post.assert_called_with(              constants.URLs.snekbox_eval_api, -            json={"input": "import random"}, +            json=expected,              raise_for_status=True          )          resp.json.assert_awaited_once()      async def test_upload_output_reject_too_long(self):          """Reject output longer than MAX_PASTE_LENGTH.""" -        result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) +        result = await self.cog.upload_output("-" * (snekbox._cog.MAX_PASTE_LENGTH + 1))          self.assertEqual(result, "too long to upload") -    @patch("bot.exts.utils.snekbox.send_to_paste_service") +    @patch("bot.exts.utils.snekbox._cog.send_to_paste_service")      async def test_upload_output(self, mock_paste_util):          """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""          await self.cog.upload_output("Test output.") -        mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) +        mock_paste_util.assert_called_once_with( +            "Test output.", +            extension="txt", +            max_length=snekbox._cog.MAX_PASTE_LENGTH +        )      async def test_codeblock_converter(self):          ctx = MockContext() @@ -76,40 +99,94 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code')          ) -        for case, setup_code, testname in cases: -            setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) -            expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) -            with self.subTest(msg=f'Test with {testname} and expected return {expected}'): +        for case, setup_code, test_name in cases: +            setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) +            expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)] +            with self.subTest(msg=f'Test with {test_name} and expected return {expected}'):                  self.assertEqual(self.cog.prepare_timeit_input(case), expected) -    def test_get_results_message(self): -        """Return error and message according to the eval result.""" +    def test_eval_result_message(self): +        """EvalResult.get_message(), should return message."""          cases = ( -            ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), -            ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), -            ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) +            ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')), +            ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')), +            ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', ''))          )          for stdout, returncode, expected in cases: +            exp_msg, exp_err, exp_files_err = expected              with self.subTest(stdout=stdout, returncode=returncode, expected=expected): -                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') -                self.assertEqual(actual, expected) - -    @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) -    def test_get_results_message_invalid_signal(self, mock_signals: Mock): +                result = EvalResult(stdout=stdout, returncode=returncode) +                job = EvalJob([]) +                # Check all 3 message types +                msg = result.get_message(job) +                self.assertEqual(msg, exp_msg) +                error = result.error_message +                self.assertEqual(error, exp_err) +                files_error = result.files_error_message +                self.assertEqual(files_error, exp_files_err) + +    @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) +    def test_eval_result_files_error_message(self): +        """EvalResult.files_error_message, should return files error message.""" +        cases = [ +            ([], ["abc"], ( +                "1 file upload (abc) failed because its file size exceeds 8 MiB." +            )), +            ([], ["file1.bin", "f2.bin"], ( +                "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." +            )), +            (["a", "b"], ["c"], ( +                "1 file upload (c) failed as it exceeded the 2 file limit." +            )), +            (["a"], ["b", "c"], ( +                "2 file uploads (b, c) failed as they exceeded the 2 file limit." +            )), +        ] +        for files, failed_files, expected_msg in cases: +            with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): +                result = EvalResult("", 0, files, failed_files) +                msg = result.files_error_message +                self.assertIn(expected_msg, msg) + +    @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) +    def test_eval_result_files_error_str(self): +        """EvalResult.files_error_message, should return files error message.""" +        cases = [ +            # Normal +            (["x.ini"], "x.ini"), +            (["123456", "879"], "123456, 879"), +            # Break on whole name if less than 3 characters remaining +            (["12345678", "9"], "12345678, ..."), +            # Otherwise break on max chars +            (["123", "345", "67890000"], "123, 345, 6789..."), +            (["abcdefg1234567"], "abcdefg123..."), +        ] +        for failed_files, expected in cases: +            with self.subTest(failed_files=failed_files, expected=expected): +                result = EvalResult("", 0, [], failed_files) +                msg = result.get_failed_files_str(char_max=10) +                self.assertEqual(msg, expected) + +    @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) +    def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): +        result = EvalResult(stdout="", returncode=127)          self.assertEqual( -            self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), -            ('Your 3.11 eval job has completed with return code 127', '') +            result.get_message(EvalJob([], version="3.10")), +            "Your 3.10 eval job has completed with return code 127"          ) +        self.assertEqual(result.error_message, "") +        self.assertEqual(result.files_error_message, "") -    @patch('bot.exts.utils.snekbox.Signals') -    def test_get_results_message_valid_signal(self, mock_signals: Mock): -        mock_signals.return_value.name = 'SIGTEST' +    @patch('bot.exts.utils.snekbox._eval.Signals') +    def test_eval_result_message_valid_signal(self, mock_signals: Mock): +        mock_signals.return_value.name = "SIGTEST" +        result = EvalResult(stdout="", returncode=127)          self.assertEqual( -            self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), -            ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') +            result.get_message(EvalJob([], version="3.11")), +            "Your 3.11 eval job has completed with return code 127 (SIGTEST)"          ) -    def test_get_status_emoji(self): +    def test_eval_result_status_emoji(self):          """Return emoji according to the eval result."""          cases = (              (' ', -1, ':warning:'), @@ -118,8 +195,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          )          for stdout, returncode, expected in cases:              with self.subTest(stdout=stdout, returncode=returncode, expected=expected): -                actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) -                self.assertEqual(actual, expected) +                result = EvalResult(stdout=stdout, returncode=returncode) +                self.assertEqual(result.status_emoji, expected)      async def test_format_output(self):          """Test output formatting.""" @@ -178,10 +255,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          ctx.command = MagicMock()          self.cog.send_job = AsyncMock(return_value=response) -        self.cog.continue_job = AsyncMock(return_value=(None, None)) +        self.cog.continue_job = AsyncMock(return_value=None)          await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) -        self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') +        job = EvalJob.from_code("MyAwesomeCode") +        self.cog.send_job.assert_called_once_with(ctx, job)          self.cog.continue_job.assert_called_once_with(ctx, response, 'eval')      async def test_eval_command_evaluate_twice(self): @@ -191,13 +269,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          ctx.command = MagicMock()          self.cog.send_job = AsyncMock(return_value=response)          self.cog.continue_job = AsyncMock() -        self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) +        self.cog.continue_job.side_effect = (EvalJob.from_code('MyAwesomeFormattedCode'), None)          await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) -        self.cog.send_job.assert_called_with( -            ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' -        ) -        self.cog.continue_job.assert_called_with(ctx, response, 'eval') + +        expected_job = EvalJob.from_code("MyAwesomeFormattedCode") +        self.cog.send_job.assert_called_with(ctx, expected_job) +        self.cog.continue_job.assert_called_with(ctx, response, "eval")      async def test_eval_command_reject_two_eval_at_the_same_time(self):          """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -212,8 +290,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect)          with self.assertRaises(LockedResourceError):              await asyncio.gather( -                self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), -                self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), +                self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), +                self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")),              )      async def test_send_job(self): @@ -223,30 +301,31 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          ctx.send = AsyncMock()          ctx.author = MockUser(mention='@LemonLemonishBeard#0042') -        self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) -        self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) -        self.cog.get_status_emoji = MagicMock(return_value=':yay!:') +        eval_result = EvalResult("", 0) +        self.cog.post_job = AsyncMock(return_value=eval_result)          self.cog.format_output = AsyncMock(return_value=('[No output]', None)) +        self.cog.upload_output = AsyncMock()  # Should not be called          mocked_filter_cog = MagicMock()          mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') +        job = EvalJob.from_code('MyAwesomeCode') +        await self.cog.send_job(ctx, job),          ctx.send.assert_called_once()          self.assertEqual(              ctx.send.call_args.args[0], -            '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' +            '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed ' +            'with return code 0.\n\n```\n[No output]\n```'          )          allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions']          expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])          self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) -        self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) -        self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) -        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') +        self.cog.post_job.assert_called_once_with(job)          self.cog.format_output.assert_called_once_with('') +        self.cog.upload_output.assert_not_called()      async def test_send_job_with_paste_link(self):          """Test the send_job function with a too long output that generate a paste link.""" @@ -255,29 +334,26 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) -        self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) -        self.cog.get_status_emoji = MagicMock(return_value=':yay!:') +        eval_result = EvalResult("Way too long beard", 0) +        self.cog.post_job = AsyncMock(return_value=eval_result)          self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))          mocked_filter_cog = MagicMock()          mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') +        job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") +        await self.cog.send_job(ctx, job),          ctx.send.assert_called_once()          self.assertEqual(              ctx.send.call_args.args[0], -            '@LemonLemonishBeard#0042 :yay!: Return code 0.' +            '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job ' +            'has completed with return code 0.'              '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'          ) -        self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) -        self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) -        self.cog.get_results_message.assert_called_once_with( -            {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' -        ) +        self.cog.post_job.assert_called_once_with(job)          self.cog.format_output.assert_called_once_with('Way too long beard')      async def test_send_job_with_non_zero_eval(self): @@ -286,29 +362,57 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) -        self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) -        self.cog.get_status_emoji = MagicMock(return_value=':nope!:') -        self.cog.format_output = AsyncMock()  # This function isn't called + +        eval_result = EvalResult("ERROR", 127) +        self.cog.post_job = AsyncMock(return_value=eval_result) +        self.cog.upload_output = AsyncMock()  # This function isn't called          mocked_filter_cog = MagicMock()          mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') +        job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") +        await self.cog.send_job(ctx, job),          ctx.send.assert_called_once()          self.assertEqual(              ctx.send.call_args.args[0], -            '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' +            '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.' +            '\n\n```\nERROR\n```' +        ) + +        self.cog.post_job.assert_called_once_with(job) +        self.cog.upload_output.assert_not_called() + +    async def test_send_job_with_disallowed_file_ext(self): +        """Test send_job with disallowed file extensions.""" +        ctx = MockContext() +        ctx.message = MockMessage() +        ctx.send = AsyncMock() +        ctx.author.mention = "@user#7700" + +        eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")]) +        self.cog.post_job = AsyncMock(return_value=eval_result) +        self.cog.upload_output = AsyncMock()  # This function isn't called + +        mocked_filter_cog = MagicMock() +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        self.bot.get_cog.return_value = mocked_filter_cog + +        job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") +        await self.cog.send_job(ctx, job), + +        ctx.send.assert_called_once() +        res = ctx.send.call_args.args[0] +        self.assertTrue( +            res.startswith("@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.")          ) +        self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res) -        self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) -        self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) -        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') -        self.cog.format_output.assert_not_called() +        self.cog.post_job.assert_called_once_with(job) +        self.cog.upload_output.assert_not_called() -    @patch("bot.exts.utils.snekbox.partial") +    @patch("bot.exts.utils.snekbox._cog.partial")      async def test_continue_job_does_continue(self, partial_mock):          """Test that the continue_job function does continue if required conditions are met."""          ctx = MockContext( @@ -328,19 +432,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          actual = await self.cog.continue_job(ctx, response, self.cog.eval_command)          self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) -        self.assertEqual(actual, (expected, None)) +        self.assertEqual(actual, EvalJob.from_code(expected))          self.bot.wait_for.assert_has_awaits(              (                  call(                      'message_edit', -                    check=partial_mock(snekbox.predicate_message_edit, ctx), -                    timeout=snekbox.REDO_TIMEOUT, +                    check=partial_mock(snekbox._cog.predicate_message_edit, ctx), +                    timeout=snekbox._cog.REDO_TIMEOUT,                  ), -                call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) +                call('reaction_add', check=partial_mock(snekbox._cog.predicate_emoji_reaction, ctx), timeout=10)              )          ) -        ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) -        ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) +        ctx.message.add_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) +        ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI)          response.delete.assert_called_once()      async def test_continue_job_does_not_continue(self): @@ -348,8 +452,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.bot.wait_for.side_effect = asyncio.TimeoutError          actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) -        self.assertEqual(actual, (None, None)) -        ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) +        self.assertEqual(actual, None) +        ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI)      async def test_get_code(self):          """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -391,18 +495,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          for ctx_msg, new_msg, expected, testname in cases:              with self.subTest(msg=f'Messages with {testname} return {expected}'):                  ctx = MockContext(message=ctx_msg) -                actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) +                actual = snekbox._cog.predicate_message_edit(ctx, ctx_msg, new_msg)                  self.assertEqual(actual, expected)      def test_predicate_emoji_reaction(self):          """Test the predicate_emoji_reaction function."""          valid_reaction = MockReaction(message=MockMessage(id=1)) -        valid_reaction.__str__.return_value = snekbox.REDO_EMOJI +        valid_reaction.__str__.return_value = snekbox._cog.REDO_EMOJI          valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))          valid_user = MockUser(id=2)          invalid_reaction_id = MockReaction(message=MockMessage(id=42)) -        invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI +        invalid_reaction_id.__str__.return_value = snekbox._cog.REDO_EMOJI          invalid_user_id = MockUser(id=42)          invalid_reaction_str = MockReaction(message=MockMessage(id=1))          invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -415,7 +519,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          )          for reaction, user, expected, testname in cases:              with self.subTest(msg=f'Test with {testname} and expected return {expected}'): -                actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) +                actual = snekbox._cog.predicate_emoji_reaction(valid_ctx, reaction, user)                  self.assertEqual(actual, expected)  |