diff options
-rw-r--r-- | bot/constants.py | 1 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/__init__.py | 12 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/_cog.py (renamed from bot/exts/utils/snekbox.py) | 373 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/_eval.py | 183 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/_io.py | 102 | ||||
-rw-r--r-- | tests/bot/exts/utils/snekbox/__init__.py | 0 | ||||
-rw-r--r-- | tests/bot/exts/utils/snekbox/test_io.py | 34 | ||||
-rw-r--r-- | tests/bot/exts/utils/snekbox/test_snekbox.py (renamed from tests/bot/exts/utils/test_snekbox.py) | 266 |
8 files changed, 727 insertions, 244 deletions
diff --git a/bot/constants.py b/bot/constants.py index a4d5761be..3aacd0a16 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -584,6 +584,7 @@ class _Emojis(EnvConfig): defcon_update = "<:defconsettingsupdated:470326274082996224>" # noqa: E704 failmail = "<:failmail:633660039931887616>" + failed_file = "<:failed_file:1073298441968562226>" incident_actioned = "<:incident_actioned:714221559279255583>" incident_investigating = "<:incident_investigating:714224190928191551>" diff --git a/bot/exts/utils/snekbox/__init__.py b/bot/exts/utils/snekbox/__init__.py new file mode 100644 index 000000000..cd1d3b059 --- /dev/null +++ b/bot/exts/utils/snekbox/__init__.py @@ -0,0 +1,12 @@ +from bot.bot import Bot +from bot.exts.utils.snekbox._cog import CodeblockConverter, Snekbox +from bot.exts.utils.snekbox._eval import EvalJob, EvalResult + +__all__ = ("CodeblockConverter", "Snekbox", "EvalJob", "EvalResult") + + +async def setup(bot: Bot) -> None: + """Load the Snekbox cog.""" + # Defer import to reduce side effects from importing the codeblock package. + from bot.exts.utils.snekbox._cog import Snekbox + await bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox/_cog.py index ddcbe01fa..b48fcf592 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import asyncio import contextlib import re from functools import partial from operator import attrgetter -from signal import Signals from textwrap import dedent -from typing import Literal, Optional, Tuple +from typing import Literal, NamedTuple, Optional, TYPE_CHECKING from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only @@ -13,14 +14,21 @@ from pydis_core.utils import interactions from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from bot.bot import Bot -from bot.constants import Channels, MODERATION_ROLES, Roles, URLs +from bot.constants import Channels, Emojis, Filter, MODERATION_ROLES, Roles, URLs from bot.decorators import redirect_output +from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME +from bot.exts.filters.antimalware import TXT_LIKE_FILES from bot.exts.help_channels._channel import is_help_forum_post +from bot.exts.utils.snekbox._eval import EvalJob, EvalResult +from bot.exts.utils.snekbox._io import FileAttachment from bot.log import get_logger from bot.utils import send_to_paste_service from bot.utils.lock import LockedResourceError, lock_arg from bot.utils.services import PasteTooLongError, PasteUploadError +if TYPE_CHECKING: + from bot.exts.filters.filtering import Filtering + log = get_logger(__name__) ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") @@ -68,17 +76,23 @@ if not hasattr(sys, "_setup_finished"): """ MAX_PASTE_LENGTH = 10_000 +# Max to display in a codeblock before sending to a paste service +# This also applies to text files +MAX_OUTPUT_BLOCK_LINES = 10 +MAX_OUTPUT_BLOCK_CHARS = 1000 # The Snekbox commands' whitelists and blacklists. NO_SNEKBOX_CHANNELS = (Channels.python_general,) NO_SNEKBOX_CATEGORIES = () SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) -SIGKILL = 9 - REDO_EMOJI = '\U0001f501' # :repeat: REDO_TIMEOUT = 30 +PythonVersion = Literal["3.10", "3.11"] + +FilteredFiles = NamedTuple("FilteredFiles", [("allowed", list[FileAttachment]), ("blocked", list[FileAttachment])]) + class CodeblockConverter(Converter): """Attempts to extract code from a codeblock, if provided.""" @@ -122,21 +136,17 @@ class PythonVersionSwitcherButton(ui.Button): def __init__( self, - job_name: str, - version_to_switch_to: Literal["3.10", "3.11"], - snekbox_cog: "Snekbox", + version_to_switch_to: PythonVersion, + snekbox_cog: Snekbox, ctx: Context, - code: str, - args: Optional[list[str]] = None + job: EvalJob, ) -> None: self.version_to_switch_to = version_to_switch_to super().__init__(label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary) self.snekbox_cog = snekbox_cog self.ctx = ctx - self.job_name = job_name - self.code = code - self.args = args + self.job = job async def callback(self, interaction: Interaction) -> None: """ @@ -149,13 +159,11 @@ class PythonVersionSwitcherButton(ui.Button): await interaction.response.defer() with contextlib.suppress(NotFound): - # Suppress this delete to cover the case where a user re-runs code and very quickly clicks the button. + # Suppress delete to cover the case where a user re-runs code and very quickly clicks the button. # The log arg on send_job will stop the actual job from running. await interaction.message.delete() - await self.snekbox_cog.run_job( - self.job_name, self.ctx, self.version_to_switch_to, self.code, args=self.args - ) + await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_switch_to)) class Snekbox(Cog): @@ -167,13 +175,12 @@ class Snekbox(Cog): def build_python_version_switcher_view( self, - job_name: str, - current_python_version: Literal["3.10", "3.11"], + current_python_version: PythonVersion, ctx: Context, - code: str, - args: Optional[list[str]] = None - ) -> None: + job: EvalJob, + ) -> interactions.ViewWithUserAndRoleCheck: """Return a view that allows the user to change what version of Python their code is run on.""" + alt_python_version: PythonVersion if current_python_version == "3.10": alt_python_version = "3.11" else: @@ -183,33 +190,25 @@ class Snekbox(Cog): allowed_users=(ctx.author.id,), allowed_roles=MODERATION_ROLES, ) - view.add_item(PythonVersionSwitcherButton(job_name, alt_python_version, self, ctx, code, args)) + view.add_item(PythonVersionSwitcherButton(alt_python_version, self, ctx, job)) view.add_item(interactions.DeleteMessageButton()) return view - async def post_job( - self, - code: str, - python_version: Literal["3.10", "3.11"], - *, - args: Optional[list[str]] = None - ) -> dict: + async def post_job(self, job: EvalJob) -> EvalResult: """Send a POST request to the Snekbox API to evaluate code and return the results.""" - if python_version == "3.10": + if job.version == "3.10": url = URLs.snekbox_eval_api else: url = URLs.snekbox_311_eval_api - data = {"input": code} - - if args is not None: - data["args"] = args + data = job.to_dict() async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: - return await resp.json() + return EvalResult.from_dict(await resp.json()) - async def upload_output(self, output: str) -> Optional[str]: + @staticmethod + async def upload_output(output: str) -> Optional[str]: """Upload the job's output to a paste service and return a URL to it if successful.""" log.trace("Uploading full output to paste service...") @@ -221,59 +220,27 @@ class Snekbox(Cog): return "unable to upload" @staticmethod - def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: + def prepare_timeit_input(codeblocks: list[str]) -> list[str]: """ - Join the codeblocks into a single string, then return the code and the arguments in a tuple. + Join the codeblocks into a single string, then return the arguments in a list. If there are multiple codeblocks, insert the first one into the wrapped setup code. """ args = ["-m", "timeit"] - setup = "" - if len(codeblocks) > 1: - setup = codeblocks.pop(0) - + setup_code = codeblocks.pop(0) if len(codeblocks) > 1 else "" code = "\n".join(codeblocks) - args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) - - return code, args - - @staticmethod - def get_results_message(results: dict, job_name: str, python_version: Literal["3.10", "3.11"]) -> Tuple[str, str]: - """Return a user-friendly message and error corresponding to the process's return code.""" - stdout, returncode = results["stdout"], results["returncode"] - msg = f"Your {python_version} {job_name} job has completed with return code {returncode}" - error = "" - - if returncode is None: - msg = f"Your {python_version} {job_name} job has failed" - error = stdout.strip() - elif returncode == 128 + SIGKILL: - msg = f"Your {python_version} {job_name} job timed out or ran out of memory" - elif returncode == 255: - msg = f"Your {python_version} {job_name} job has failed" - error = "A fatal NsJail error occurred" - else: - # Try to append signal's name if one exists - try: - name = Signals(returncode - 128).name - msg = f"{msg} ({name})" - except ValueError: - pass - - return msg, error + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup_code), code]) + return args - @staticmethod - def get_status_emoji(results: dict) -> str: - """Return an emoji corresponding to the status code or lack of output in result.""" - if not results["stdout"].strip(): # No output - return ":warning:" - elif results["returncode"] == 0: # No error - return ":white_check_mark:" - else: # Exception - return ":x:" - - async def format_output(self, output: str) -> Tuple[str, Optional[str]]: + async def format_output( + self, + output: str, + max_lines: int = MAX_OUTPUT_BLOCK_LINES, + max_chars: int = MAX_OUTPUT_BLOCK_CHARS, + line_nums: bool = True, + output_default: str = "[No output]", + ) -> tuple[str, str | None]: """ Format the output and return a tuple of the formatted output and a URL to the full output. @@ -295,96 +262,182 @@ class Snekbox(Cog): return "Code block escape attempt detected; will not output result", paste_link truncated = False - lines = output.count("\n") + lines = output.splitlines() - if lines > 0: - output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] - output = output[:11] # Limiting to only 11 lines - output = "\n".join(output) + if len(lines) > 1: + if line_nums: + lines = [f"{i:03d} | {line}" for i, line in enumerate(lines, 1)] + lines = lines[:max_lines+1] # Limiting to max+1 lines + output = "\n".join(lines) - if lines > 10: + if len(lines) > max_lines: truncated = True - if len(output) >= 1000: - output = f"{output[:1000]}\n... (truncated - too long, too many lines)" + if len(output) >= max_chars: + output = f"{output[:max_chars]}\n... (truncated - too long, too many lines)" else: output = f"{output}\n... (truncated - too many lines)" - elif len(output) >= 1000: + elif len(output) >= max_chars: truncated = True - output = f"{output[:1000]}\n... (truncated - too long)" + output = f"{output[:max_chars]}\n... (truncated - too long)" if truncated: paste_link = await self.upload_output(original_output) - output = output or "[No output]" + if output_default and not output: + output = output_default return output, paste_link + def get_extensions_whitelist(self) -> set[str]: + """Return a set of whitelisted file extensions.""" + return set(self.bot.filter_list_cache['FILE_FORMAT.True'].keys()) | TXT_LIKE_FILES + + def _filter_files(self, ctx: Context, files: list[FileAttachment]) -> FilteredFiles: + """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(ctx.author, "roles") and any(role.id in Filter.role_whitelist for role in ctx.author.roles): + return FilteredFiles(files, []) + # Ignore code jam channels + if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME: + return FilteredFiles(files, []) + + # Get whitelisted extensions + whitelist = self.get_extensions_whitelist() + + # Filter files into allowed and blocked + blocked = [] + allowed = [] + for file in files: + if file.suffix in whitelist: + allowed.append(file) + else: + blocked.append(file) + + if blocked: + blocked_str = ", ".join(f.suffix for f in blocked) + log.info( + f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}", + extra={"attachment_list": [f.path for f in files]} + ) + + return FilteredFiles(allowed, blocked) + @lock_arg("snekbox.send_job", "ctx", attrgetter("author.id"), raise_error=True) - async def send_job( - self, - ctx: Context, - python_version: Literal["3.10", "3.11"], - code: str, - *, - args: Optional[list[str]] = None, - job_name: str - ) -> Message: + async def send_job(self, ctx: Context, job: EvalJob) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. Return the bot response. """ async with ctx.typing(): - results = await self.post_job(code, python_version, args=args) - msg, error = self.get_results_message(results, job_name, python_version) + result = await self.post_job(job) + msg = result.get_message(job) + error = result.error_message if error: output, paste_link = error, None else: log.trace("Formatting output...") - output, paste_link = await self.format_output(results["stdout"]) + output, paste_link = await self.format_output(result.stdout) - warning_message = "" + msg = f"{ctx.author.mention} {result.status_emoji} {msg}.\n" # This is done to make sure the last line of output contains the error # and the error is not manually printed by the author with a syntax error. - if results["stdout"].rstrip().endswith("EOFError: EOF when reading a line") and results["returncode"] == 1: - warning_message += ":warning: Note: `input` is not supported by the bot :warning:\n\n" + if result.stdout.rstrip().endswith("EOFError: EOF when reading a line") and result.returncode == 1: + msg += ":warning: Note: `input` is not supported by the bot :warning:\n\n" + + # Skip output if it's empty and there are file uploads + if result.stdout or not result.has_files: + msg += f"\n```\n{output}\n```" - icon = self.get_status_emoji(results) - msg = f"{ctx.author.mention} {icon} {msg}.\n\n{warning_message}```\n{output}\n```" if paste_link: - msg = f"{msg}\nFull output: {paste_link}" + msg += f"\nFull output: {paste_link}" + + # Additional files error message after output + if files_error := result.files_error_message: + msg += f"\n{files_error}" # Collect stats of job fails + successes - if icon == ":x:": + if result.returncode != 0: self.bot.stats.incr("snekbox.python.fail") else: self.bot.stats.incr("snekbox.python.success") - filter_cog = self.bot.get_cog("Filtering") - filter_triggered = False - if filter_cog: - filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message) - if filter_triggered: - response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") - else: - allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) - view = self.build_python_version_switcher_view(job_name, python_version, ctx, code, args) - response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view) - view.message = response - - log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}") + # Filter file extensions + allowed, blocked = self._filter_files(ctx, result.files) + # Also scan failed files for blocked extensions + failed_files = [FileAttachment(name, b"") for name in result.failed_files] + blocked.extend(self._filter_files(ctx, failed_files).blocked) + # Add notice if any files were blocked + if blocked: + blocked_sorted = sorted(set(f.suffix for f in blocked)) + # Only no extension + if len(blocked_sorted) == 1 and blocked_sorted[0] == "": + blocked_msg = "Files with no extension can't be uploaded." + # Both + elif "" in blocked_sorted: + blocked_str = ", ".join(ext for ext in blocked_sorted if ext) + blocked_msg = ( + f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + ) + else: + blocked_str = ", ".join(blocked_sorted) + blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + + msg += f"\n{Emojis.failed_file} {blocked_msg}" + + # Split text files + text_files = [f for f in allowed if f.suffix in TXT_LIKE_FILES] + # Inline until budget, then upload to paste service + # Budget is shared with stdout, so subtract what we've already used + budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) + budget_chars = MAX_OUTPUT_BLOCK_CHARS - len(output) + for file in text_files: + file_text = file.content.decode("utf-8", errors="replace") or "[Empty]" + # Override to always allow 1 line and <= 50 chars, since this is less than a link + if len(file_text) <= 50 and not file_text.count("\n"): + msg += f"\n`{file.name}`\n```\n{file_text}\n```" + # otherwise, use budget + else: + format_text, link_text = await self.format_output( + file_text, + budget_lines, + budget_chars, + line_nums=False, + output_default="[Empty]" + ) + # With any link, use it (don't use budget) + if link_text: + msg += f"\n`{file.name}`\n{link_text}" + else: + msg += f"\n`{file.name}`\n```\n{format_text}\n```" + budget_lines -= format_text.count("\n") + 1 + budget_chars -= len(file_text) + + filter_cog: Filtering | None = self.bot.get_cog("Filtering") + if filter_cog and (await filter_cog.filter_snekbox_output(msg, ctx.message)): + return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + + # Upload remaining non-text files + files = [f.to_file() for f in allowed if f not in text_files] + allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) + view = self.build_python_version_switcher_view(job.version, ctx, job) + response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files) + view.message = response + + log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}") return response async def continue_job( self, ctx: Context, response: Message, job_name: str - ) -> tuple[Optional[str], Optional[list[str]]]: + ) -> EvalJob | None: """ Check if the job's session should continue. - If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command. - Otherwise return (None, None) if the job's session should be terminated. + If the code is to be re-evaluated, return the new EvalJob. + Otherwise, return None if the job's session should be terminated. """ _predicate_message_edit = partial(predicate_message_edit, ctx) _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx) @@ -406,7 +459,7 @@ class Snekbox(Cog): # Ensure the response that's about to be edited is still the most recent. # This could have already been updated via a button press to switch to an alt Python version. if self.jobs[ctx.message.id] != response.id: - return None, None + return None code = await self.get_code(new_message, ctx.command) with contextlib.suppress(HTTPException): @@ -414,21 +467,21 @@ class Snekbox(Cog): await response.delete() if code is None: - return None, None + return None except asyncio.TimeoutError: with contextlib.suppress(HTTPException): await ctx.message.clear_reaction(REDO_EMOJI) - return None, None + return None codeblocks = await CodeblockConverter.convert(ctx, code) if job_name == "timeit": - return self.prepare_timeit_input(codeblocks) + return EvalJob(self.prepare_timeit_input(codeblocks)) else: - return "\n".join(codeblocks), None + return EvalJob.from_code("\n".join(codeblocks)) - return None, None + return None async def get_code(self, message: Message, command: Command) -> Optional[str]: """ @@ -452,12 +505,8 @@ class Snekbox(Cog): async def run_job( self, - job_name: str, ctx: Context, - python_version: Literal["3.10", "3.11"], - code: str, - *, - args: Optional[list[str]] = None, + job: EvalJob, ) -> None: """Handles checks, stats and re-evaluation of a snekbox job.""" if Roles.helpers in (role.id for role in ctx.author.roles): @@ -472,11 +521,11 @@ class Snekbox(Cog): else: self.bot.stats.incr("snekbox_usages.channels.topical") - log.info(f"Received code from {ctx.author} for evaluation:\n{code}") + log.info(f"Received code from {ctx.author} for evaluation:\n{job}") while True: try: - response = await self.send_job(ctx, python_version, code, args=args, job_name=job_name) + response = await self.send_job(ctx, job) except LockedResourceError: await ctx.send( f"{ctx.author.mention} You've already got a job running - " @@ -489,10 +538,10 @@ class Snekbox(Cog): # This can happen when a button is pressed and then original code is edited and re-run. self.jobs[ctx.message.id] = response.id - code, args = await self.continue_job(ctx, response, job_name) - if not code: + job = await self.continue_job(ctx, response, job.name) + if not job: break - log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + log.info(f"Re-evaluating code from message {ctx.message.id}:\n{job}") @command(name="eval", aliases=("e",), usage="[python_version] <code, ...>") @guild_only() @@ -506,29 +555,32 @@ class Snekbox(Cog): async def eval_command( self, ctx: Context, - python_version: Optional[Literal["3.10", "3.11"]], + python_version: PythonVersion | None, *, code: CodeblockConverter ) -> None: """ Run Python code and get the results. - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and + This command supports multiple lines of code, including formatted code blocks. + Code can be re-evaluated by editing the original message within 10 seconds and clicking the reaction that subsequently appears. + The starting working directory `/home`, is a writeable temporary file system. + Files created, excluding names with leading underscores, will be uploaded in the response. + If multiple codeblocks are in a message, all of them will be joined and evaluated, - ignoring the text outside of them. + ignoring the text outside them. - By default your code is run on Python's 3.11 beta release, to assist with testing. If you - run into issues related to this Python version, you can request the bot to use Python - 3.10 by specifying the `python_version` arg and setting it to `3.10`. + By default, your code is run on Python 3.11. A `python_version` arg of `3.10` can also be specified. We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ + code: list[str] python_version = python_version or "3.11" - await self.run_job("eval", ctx, python_version, "\n".join(code)) + job = EvalJob.from_code("\n".join(code)).as_version(python_version) + await self.run_job(ctx, job) @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] <code, ...>") @guild_only() @@ -542,7 +594,7 @@ class Snekbox(Cog): async def timeit_command( self, ctx: Context, - python_version: Optional[Literal["3.10", "3.11"]], + python_version: PythonVersion | None, *, code: CodeblockConverter ) -> None: @@ -556,17 +608,17 @@ class Snekbox(Cog): If multiple formatted codeblocks are provided, the first one will be the setup code, which will not be timed. The remaining codeblocks will be joined together and timed. - By default your code is run on Python's 3.11 beta release, to assist with testing. If you - run into issues related to this Python version, you can request the bot to use Python - 3.10 by specifying the `python_version` arg and setting it to `3.10`. + By default, your code is run on Python 3.11. A `python_version` arg of `3.10` can also be specified. We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ + code: list[str] python_version = python_version or "3.11" - code, args = self.prepare_timeit_input(code) + args = self.prepare_timeit_input(code) + job = EvalJob(args, version=python_version, name="timeit") - await self.run_job("timeit", ctx, python_version, code=code, args=args) + await self.run_job(ctx, job) def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: @@ -577,8 +629,3 @@ def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI - - -async def setup(bot: Bot) -> None: - """Load the Snekbox cog.""" - await bot.add_cog(Snekbox(bot)) diff --git a/bot/exts/utils/snekbox/_eval.py b/bot/exts/utils/snekbox/_eval.py new file mode 100644 index 000000000..2f61b5924 --- /dev/null +++ b/bot/exts/utils/snekbox/_eval.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from signal import Signals +from typing import TYPE_CHECKING + +from discord.utils import escape_markdown, escape_mentions + +from bot.constants import Emojis +from bot.exts.utils.snekbox._io import FILE_COUNT_LIMIT, FILE_SIZE_LIMIT, FileAttachment, sizeof_fmt +from bot.log import get_logger + +if TYPE_CHECKING: + from bot.exts.utils.snekbox._cog import PythonVersion + +log = get_logger(__name__) + +SIGKILL = 9 + + +@dataclass(frozen=True) +class EvalJob: + """Job to be evaluated by snekbox.""" + + args: list[str] + files: list[FileAttachment] = field(default_factory=list) + name: str = "eval" + version: PythonVersion = "3.11" + + @classmethod + def from_code(cls, code: str, path: str = "main.py") -> EvalJob: + """Create an EvalJob from a code string.""" + return cls( + args=[path], + files=[FileAttachment(path, code.encode())], + ) + + def as_version(self, version: PythonVersion) -> EvalJob: + """Return a copy of the job with a different Python version.""" + return EvalJob( + args=self.args, + files=self.files, + name=self.name, + version=version, + ) + + def to_dict(self) -> dict[str, list[str | dict[str, str]]]: + """Convert the job to a dict.""" + return { + "args": self.args, + "files": [file.to_dict() for file in self.files], + } + + +@dataclass(frozen=True) +class EvalResult: + """The result of an eval job.""" + + stdout: str + returncode: int | None + files: list[FileAttachment] = field(default_factory=list) + failed_files: list[str] = field(default_factory=list) + + @property + def has_output(self) -> bool: + """True if the result has any output (stdout, files, or failed files).""" + return bool(self.stdout.strip() or self.files or self.failed_files) + + @property + def has_files(self) -> bool: + """True if the result has any files or failed files.""" + return bool(self.files or self.failed_files) + + @property + def status_emoji(self) -> str: + """Return an emoji corresponding to the status code or lack of output in result.""" + if not self.has_output: + return ":warning:" + elif self.returncode == 0: # No error + return ":white_check_mark:" + else: # Exception + return ":x:" + + @property + def error_message(self) -> str: + """Return an error message corresponding to the process's return code.""" + error = "" + if self.returncode is None: + error = self.stdout.strip() + elif self.returncode == 255: + error = "A fatal NsJail error occurred" + return error + + @property + def files_error_message(self) -> str: + """Return an error message corresponding to the failed files.""" + if not self.failed_files: + return "" + + failed_files = f"({self.get_failed_files_str()})" + + n_failed = len(self.failed_files) + s_upload = "uploads" if n_failed > 1 else "upload" + + msg = f"{Emojis.failed_file} {n_failed} file {s_upload} {failed_files} failed" + + # Exceeded file count limit + if (n_failed + len(self.files)) > FILE_COUNT_LIMIT: + s_it = "they" if n_failed > 1 else "it" + msg += f" as {s_it} exceeded the {FILE_COUNT_LIMIT} file limit." + # Exceeded file size limit + else: + s_each_file = "each file's" if n_failed > 1 else "its file" + msg += f" because {s_each_file} size exceeds {sizeof_fmt(FILE_SIZE_LIMIT)}." + + return msg + + def get_failed_files_str(self, char_max: int = 85) -> str: + """ + Return a string containing the names of failed files, truncated char_max. + + Will truncate on whole file names if less than 3 characters remaining. + """ + names = [] + for file in self.failed_files: + # Only attempt to truncate name if more than 3 chars remaining + if char_max < 3: + names.append("...") + break + + if len(file) > char_max: + names.append(file[:char_max] + "...") + break + char_max -= len(file) + names.append(file) + + text = ", ".join(names) + # Since the file names are provided by user + text = escape_markdown(text) + text = escape_mentions(text) + return text + + def get_message(self, job: EvalJob) -> str: + """Return a user-friendly message corresponding to the process's return code.""" + msg = f"Your {job.version} {job.name} job" + + if self.returncode is None: + msg += " has failed" + elif self.returncode == 128 + SIGKILL: + msg += " timed out or ran out of memory" + elif self.returncode == 255: + msg += " has failed" + else: + msg += f" has completed with return code {self.returncode}" + # Try to append signal's name if one exists + with contextlib.suppress(ValueError): + name = Signals(self.returncode - 128).name + msg += f" ({name})" + + return msg + + @classmethod + def from_dict(cls, data: dict[str, str | int | list[dict[str, str]]]) -> EvalResult: + """Create an EvalResult from a dict.""" + res = cls( + stdout=data["stdout"], + returncode=data["returncode"], + ) + + files = iter(data["files"]) + for i, file in enumerate(files): + # Limit to FILE_COUNT_LIMIT files + if i >= FILE_COUNT_LIMIT: + res.failed_files.extend(file["path"] for file in files) + break + try: + res.files.append(FileAttachment.from_dict(file)) + except ValueError as e: + log.info(f"Failed to parse file from snekbox response: {e}") + res.failed_files.append(file["path"]) + + return res diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py new file mode 100644 index 000000000..9be396335 --- /dev/null +++ b/bot/exts/utils/snekbox/_io.py @@ -0,0 +1,102 @@ +"""I/O File protocols for snekbox.""" +from __future__ import annotations + +from base64 import b64decode, b64encode +from dataclasses import dataclass +from io import BytesIO +from pathlib import PurePosixPath + +import regex +from discord import File + +# Note discord bot upload limit is 8 MiB per file, +# or 50 MiB for lvl 2 boosted servers +FILE_SIZE_LIMIT = 8 * 1024 * 1024 + +# Discord currently has a 10-file limit per message +FILE_COUNT_LIMIT = 10 + + +# ANSI escape sequences +RE_ANSI = regex.compile(r"\\u.*\[(.*?)m") +# Characters with a leading backslash +RE_BACKSLASH = regex.compile(r"\\.") +# Discord disallowed file name characters +RE_DISCORD_FILE_NAME_DISALLOWED = regex.compile(r"[^a-zA-Z0-9._-]+") + + +def sizeof_fmt(num: int | float, suffix: str = "B") -> str: + """Return a human-readable file size.""" + num = float(num) + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < 1024: + num_str = f"{int(num)}" if num.is_integer() else f"{num:3.1f}" + return f"{num_str} {unit}{suffix}" + num /= 1024 + num_str = f"{int(num)}" if num.is_integer() else f"{num:3.1f}" + return f"{num_str} Yi{suffix}" + + +def normalize_discord_file_name(name: str) -> str: + """Return a normalized valid discord file name.""" + # Discord file names only allow A-Z, a-z, 0-9, underscores, dashes, and dots + # https://discord.com/developers/docs/reference#uploading-files + # Server will remove any other characters, but we'll get a 400 error for \ escaped chars + name = RE_ANSI.sub("_", name) + name = RE_BACKSLASH.sub("_", name) + # Replace any disallowed character with an underscore + name = RE_DISCORD_FILE_NAME_DISALLOWED.sub("_", name) + return name + + +@dataclass(frozen=True) +class FileAttachment: + """File Attachment from Snekbox eval.""" + + path: str + content: bytes + + def __repr__(self) -> str: + """Return the content as a string.""" + content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content + return f"FileAttachment(path={self.path!r}, content={content})" + + @property + def suffix(self) -> str: + """Return the file suffix.""" + return PurePosixPath(self.path).suffix + + @property + def name(self) -> str: + """Return the file name.""" + return PurePosixPath(self.path).name + + @classmethod + def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: + """Create a FileAttachment from a dict response.""" + size = data.get("size") + if (size and size > size_limit) or (len(data["content"]) > size_limit): + raise ValueError("File size exceeds limit") + + content = b64decode(data["content"]) + + if len(content) > size_limit: + raise ValueError("File size exceeds limit") + + return cls(data["path"], content) + + def to_dict(self) -> dict[str, str]: + """Convert the attachment to a json dict.""" + content = self.content + if isinstance(content, str): + content = content.encode("utf-8") + + return { + "path": self.path, + "content": b64encode(content).decode("ascii"), + } + + def to_file(self) -> File: + """Convert to a discord.File.""" + name = normalize_discord_file_name(self.name) + return File(BytesIO(self.content), filename=name) diff --git a/tests/bot/exts/utils/snekbox/__init__.py b/tests/bot/exts/utils/snekbox/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/utils/snekbox/__init__.py diff --git a/tests/bot/exts/utils/snekbox/test_io.py b/tests/bot/exts/utils/snekbox/test_io.py new file mode 100644 index 000000000..bcf1162b8 --- /dev/null +++ b/tests/bot/exts/utils/snekbox/test_io.py @@ -0,0 +1,34 @@ +from unittest import TestCase + +# noinspection PyProtectedMember +from bot.exts.utils.snekbox import _io + + +class SnekboxIOTests(TestCase): + # noinspection SpellCheckingInspection + def test_normalize_file_name(self): + """Invalid file names should be normalized.""" + cases = [ + # ANSI escape sequences -> underscore + (r"\u001b[31mText", "_Text"), + # (Multiple consecutive should be collapsed to one underscore) + (r"a\u001b[35m\u001b[37mb", "a_b"), + # Backslash escaped chars -> underscore + (r"\n", "_"), + (r"\r", "_"), + (r"A\0\tB", "A__B"), + # Any other disallowed chars -> underscore + (r"\\.txt", "_.txt"), + (r"A!@#$%^&*B, C()[]{}+=D.txt", "A_B_C_D.txt"), # noqa: P103 + (" ", "_"), + # Normal file names should be unchanged + ("legal_file-name.txt", "legal_file-name.txt"), + ("_-.", "_-."), + ] + for name, expected in cases: + with self.subTest(name=name, expected=expected): + # Test function directly + self.assertEqual(_io.normalize_discord_file_name(name), expected) + # Test FileAttachment.to_file() + obj = _io.FileAttachment(name, b"") + self.assertEqual(obj.to_file().filename, expected) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index b1f32c210..9dcf7fd8c 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -1,5 +1,6 @@ import asyncio import unittest +from base64 import b64encode from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch from discord import AllowedMentions @@ -8,7 +9,8 @@ from discord.ext import commands from bot import constants from bot.errors import LockedResourceError from bot.exts.utils import snekbox -from bot.exts.utils.snekbox import Snekbox +from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox +from bot.exts.utils.snekbox._io import FileAttachment from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser @@ -17,34 +19,55 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Add mocked bot and cog to the instance.""" self.bot = MockBot() self.cog = Snekbox(bot=self.bot) + self.job = EvalJob.from_code("import random") + + @staticmethod + def code_args(code: str) -> tuple[EvalJob]: + """Converts code to a tuple of arguments expected.""" + return EvalJob.from_code(code), async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() - resp.json = AsyncMock(return_value="return") + resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []}) context_manager = MagicMock() context_manager.__aenter__.return_value = resp self.bot.http_session.post.return_value = context_manager - self.assertEqual(await self.cog.post_job("import random", "3.10"), "return") + job = EvalJob.from_code("import random").as_version("3.10") + self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137)) + + expected = { + "args": ["main.py"], + "files": [ + { + "path": "main.py", + "content": b64encode("import random".encode()).decode() + } + ] + } self.bot.http_session.post.assert_called_with( constants.URLs.snekbox_eval_api, - json={"input": "import random"}, + json=expected, raise_for_status=True ) resp.json.assert_awaited_once() async def test_upload_output_reject_too_long(self): """Reject output longer than MAX_PASTE_LENGTH.""" - result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LENGTH + 1)) + result = await self.cog.upload_output("-" * (snekbox._cog.MAX_PASTE_LENGTH + 1)) self.assertEqual(result, "too long to upload") - @patch("bot.exts.utils.snekbox.send_to_paste_service") + @patch("bot.exts.utils.snekbox._cog.send_to_paste_service") async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with("Test output.", extension="txt", max_length=snekbox.MAX_PASTE_LENGTH) + mock_paste_util.assert_called_once_with( + "Test output.", + extension="txt", + max_length=snekbox._cog.MAX_PASTE_LENGTH + ) async def test_codeblock_converter(self): ctx = MockContext() @@ -76,40 +99,94 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') ) - for case, setup_code, testname in cases: - setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) - expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) - with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + for case, setup_code, test_name in cases: + setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) + expected = [*base_args, setup, '\n'.join(case[1:] if setup_code else case)] + with self.subTest(msg=f'Test with {test_name} and expected return {expected}'): self.assertEqual(self.cog.prepare_timeit_input(case), expected) - def test_get_results_message(self): - """Return error and message according to the eval result.""" + def test_eval_result_message(self): + """EvalResult.get_message(), should return message.""" cases = ( - ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR')), - ('', 128 + snekbox.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '')), - ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred')) + ('ERROR', None, ('Your 3.11 eval job has failed', 'ERROR', '')), + ('', 128 + snekbox._eval.SIGKILL, ('Your 3.11 eval job timed out or ran out of memory', '', '')), + ('', 255, ('Your 3.11 eval job has failed', 'A fatal NsJail error occurred', '')) ) for stdout, returncode, expected in cases: + exp_msg, exp_err, exp_files_err = expected with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval', '3.11') - self.assertEqual(actual, expected) - - @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError) - def test_get_results_message_invalid_signal(self, mock_signals: Mock): + result = EvalResult(stdout=stdout, returncode=returncode) + job = EvalJob([]) + # Check all 3 message types + msg = result.get_message(job) + self.assertEqual(msg, exp_msg) + error = result.error_message + self.assertEqual(error, exp_err) + files_error = result.files_error_message + self.assertEqual(files_error, exp_files_err) + + @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) + def test_eval_result_files_error_message(self): + """EvalResult.files_error_message, should return files error message.""" + cases = [ + ([], ["abc"], ( + "1 file upload (abc) failed because its file size exceeds 8 MiB." + )), + ([], ["file1.bin", "f2.bin"], ( + "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." + )), + (["a", "b"], ["c"], ( + "1 file upload (c) failed as it exceeded the 2 file limit." + )), + (["a"], ["b", "c"], ( + "2 file uploads (b, c) failed as they exceeded the 2 file limit." + )), + ] + for files, failed_files, expected_msg in cases: + with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): + result = EvalResult("", 0, files, failed_files) + msg = result.files_error_message + self.assertIn(expected_msg, msg) + + @patch("bot.exts.utils.snekbox._eval.FILE_COUNT_LIMIT", 2) + def test_eval_result_files_error_str(self): + """EvalResult.files_error_message, should return files error message.""" + cases = [ + # Normal + (["x.ini"], "x.ini"), + (["123456", "879"], "123456, 879"), + # Break on whole name if less than 3 characters remaining + (["12345678", "9"], "12345678, ..."), + # Otherwise break on max chars + (["123", "345", "67890000"], "123, 345, 6789..."), + (["abcdefg1234567"], "abcdefg123..."), + ] + for failed_files, expected in cases: + with self.subTest(failed_files=failed_files, expected=expected): + result = EvalResult("", 0, [], failed_files) + msg = result.get_failed_files_str(char_max=10) + self.assertEqual(msg, expected) + + @patch('bot.exts.utils.snekbox._eval.Signals', side_effect=ValueError) + def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): + result = EvalResult(stdout="", returncode=127) self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), - ('Your 3.11 eval job has completed with return code 127', '') + result.get_message(EvalJob([], version="3.10")), + "Your 3.10 eval job has completed with return code 127" ) + self.assertEqual(result.error_message, "") + self.assertEqual(result.files_error_message, "") - @patch('bot.exts.utils.snekbox.Signals') - def test_get_results_message_valid_signal(self, mock_signals: Mock): - mock_signals.return_value.name = 'SIGTEST' + @patch('bot.exts.utils.snekbox._eval.Signals') + def test_eval_result_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = "SIGTEST" + result = EvalResult(stdout="", returncode=127) self.assertEqual( - self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval', '3.11'), - ('Your 3.11 eval job has completed with return code 127 (SIGTEST)', '') + result.get_message(EvalJob([], version="3.11")), + "Your 3.11 eval job has completed with return code 127 (SIGTEST)" ) - def test_get_status_emoji(self): + def test_eval_result_status_emoji(self): """Return emoji according to the eval result.""" cases = ( (' ', -1, ':warning:'), @@ -118,8 +195,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): - actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) - self.assertEqual(actual, expected) + result = EvalResult(stdout=stdout, returncode=returncode) + self.assertEqual(result.status_emoji, expected) async def test_format_output(self): """Test output formatting.""" @@ -178,10 +255,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.command = MagicMock() self.cog.send_job = AsyncMock(return_value=response) - self.cog.continue_job = AsyncMock(return_value=(None, None)) + self.cog.continue_job = AsyncMock(return_value=None) await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) - self.cog.send_job.assert_called_once_with(ctx, '3.11', 'MyAwesomeCode', args=None, job_name='eval') + job = EvalJob.from_code("MyAwesomeCode") + self.cog.send_job.assert_called_once_with(ctx, job) self.cog.continue_job.assert_called_once_with(ctx, response, 'eval') async def test_eval_command_evaluate_twice(self): @@ -191,13 +269,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.command = MagicMock() self.cog.send_job = AsyncMock(return_value=response) self.cog.continue_job = AsyncMock() - self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + self.cog.continue_job.side_effect = (EvalJob.from_code('MyAwesomeFormattedCode'), None) await self.cog.eval_command(self.cog, ctx=ctx, python_version='3.11', code=['MyAwesomeCode']) - self.cog.send_job.assert_called_with( - ctx, '3.11', 'MyAwesomeFormattedCode', args=None, job_name='eval' - ) - self.cog.continue_job.assert_called_with(ctx, response, 'eval') + + expected_job = EvalJob.from_code("MyAwesomeFormattedCode") + self.cog.send_job.assert_called_with(ctx, expected_job) + self.cog.continue_job.assert_called_with(ctx, response, "eval") async def test_eval_command_reject_two_eval_at_the_same_time(self): """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -212,8 +290,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.post_job = AsyncMock(side_effect=delay_with_side_effect) with self.assertRaises(LockedResourceError): await asyncio.gather( - self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), - self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval'), + self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), + self.cog.send_job(ctx, EvalJob.from_code("MyAwesomeCode")), ) async def test_send_job(self): @@ -223,30 +301,31 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author = MockUser(mention='@LemonLemonishBeard#0042') - self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0}) - self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) - self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + eval_result = EvalResult("", 0) + self.cog.post_job = AsyncMock(return_value=eval_result) self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + self.cog.upload_output = AsyncMock() # Should not be called mocked_filter_cog = MagicMock() mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + job = EvalJob.from_code('MyAwesomeCode') + await self.cog.send_job(ctx, job), ctx.send.assert_called_once() self.assertEqual( ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```' + '@LemonLemonishBeard#0042 :warning: Your 3.11 eval job has completed ' + 'with return code 0.\n\n```\n[No output]\n```' ) allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions'] expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) - self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval', '3.11') + self.cog.post_job.assert_called_once_with(job) self.cog.format_output.assert_called_once_with('') + self.cog.upload_output.assert_not_called() async def test_send_job_with_paste_link(self): """Test the send_job function with a too long output that generate a paste link.""" @@ -255,29 +334,26 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) - self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + eval_result = EvalResult("Way too long beard", 0) + self.cog.post_job = AsyncMock(return_value=eval_result) self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), ctx.send.assert_called_once() self.assertEqual( ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :yay!: Return code 0.' + '@LemonLemonishBeard#0042 :white_check_mark: Your 3.11 eval job ' + 'has completed with return code 0.' '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) - self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) - self.cog.get_results_message.assert_called_once_with( - {'stdout': 'Way too long beard', 'returncode': 0}, 'eval', '3.11' - ) + self.cog.post_job.assert_called_once_with(job) self.cog.format_output.assert_called_once_with('Way too long beard') async def test_send_job_with_non_zero_eval(self): @@ -286,29 +362,57 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ctx.message = MockMessage() ctx.send = AsyncMock() ctx.author.mention = '@LemonLemonishBeard#0042' - self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) - self.cog.get_status_emoji = MagicMock(return_value=':nope!:') - self.cog.format_output = AsyncMock() # This function isn't called + + eval_result = EvalResult("ERROR", 127) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_job(ctx, '3.11', 'MyAwesomeCode', job_name='eval') + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), ctx.send.assert_called_once() self.assertEqual( ctx.send.call_args.args[0], - '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' + '@LemonLemonishBeard#0042 :x: Your 3.11 eval job has completed with return code 127.' + '\n\n```\nERROR\n```' + ) + + self.cog.post_job.assert_called_once_with(job) + self.cog.upload_output.assert_not_called() + + async def test_send_job_with_disallowed_file_ext(self): + """Test send_job with disallowed file extensions.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = "@user#7700" + + eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")]) + self.cog.post_job = AsyncMock(return_value=eval_result) + self.cog.upload_output = AsyncMock() # This function isn't called + + mocked_filter_cog = MagicMock() + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + self.bot.get_cog.return_value = mocked_filter_cog + + job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") + await self.cog.send_job(ctx, job), + + ctx.send.assert_called_once() + res = ctx.send.call_args.args[0] + self.assertTrue( + res.startswith("@user#7700 :white_check_mark: Your 3.11 eval job has completed with return code 0.") ) + self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res) - self.cog.post_job.assert_called_once_with('MyAwesomeCode', '3.11', args=None) - self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) - self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval', '3.11') - self.cog.format_output.assert_not_called() + self.cog.post_job.assert_called_once_with(job) + self.cog.upload_output.assert_not_called() - @patch("bot.exts.utils.snekbox.partial") + @patch("bot.exts.utils.snekbox._cog.partial") async def test_continue_job_does_continue(self, partial_mock): """Test that the continue_job function does continue if required conditions are met.""" ctx = MockContext( @@ -328,19 +432,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, (expected, None)) + self.assertEqual(actual, EvalJob.from_code(expected)) self.bot.wait_for.assert_has_awaits( ( call( 'message_edit', - check=partial_mock(snekbox.predicate_message_edit, ctx), - timeout=snekbox.REDO_TIMEOUT, + check=partial_mock(snekbox._cog.predicate_message_edit, ctx), + timeout=snekbox._cog.REDO_TIMEOUT, ), - call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10) + call('reaction_add', check=partial_mock(snekbox._cog.predicate_emoji_reaction, ctx), timeout=10) ) ) - ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + ctx.message.add_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) + ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) response.delete.assert_called_once() async def test_continue_job_does_not_continue(self): @@ -348,8 +452,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.wait_for.side_effect = asyncio.TimeoutError actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) - self.assertEqual(actual, (None, None)) - ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI) + self.assertEqual(actual, None) + ctx.message.clear_reaction.assert_called_once_with(snekbox._cog.REDO_EMOJI) async def test_get_code(self): """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -391,18 +495,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): for ctx_msg, new_msg, expected, testname in cases: with self.subTest(msg=f'Messages with {testname} return {expected}'): ctx = MockContext(message=ctx_msg) - actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg) + actual = snekbox._cog.predicate_message_edit(ctx, ctx_msg, new_msg) self.assertEqual(actual, expected) def test_predicate_emoji_reaction(self): """Test the predicate_emoji_reaction function.""" valid_reaction = MockReaction(message=MockMessage(id=1)) - valid_reaction.__str__.return_value = snekbox.REDO_EMOJI + valid_reaction.__str__.return_value = snekbox._cog.REDO_EMOJI valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) valid_user = MockUser(id=2) invalid_reaction_id = MockReaction(message=MockMessage(id=42)) - invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI + invalid_reaction_id.__str__.return_value = snekbox._cog.REDO_EMOJI invalid_user_id = MockUser(id=42) invalid_reaction_str = MockReaction(message=MockMessage(id=1)) invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -415,7 +519,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for reaction, user, expected, testname in cases: with self.subTest(msg=f'Test with {testname} and expected return {expected}'): - actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user) + actual = snekbox._cog.predicate_emoji_reaction(valid_ctx, reaction, user) self.assertEqual(actual, expected) |