diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/decorators.py | 33 | ||||
| -rw-r--r-- | bot/exts/filters/filtering.py | 4 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 9 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/management.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/modlog.py | 6 | ||||
| -rw-r--r-- | bot/exts/moderation/voice_gate.py | 7 | ||||
| -rw-r--r-- | bot/exts/utils/snekbox.py | 313 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 159 | 
9 files changed, 357 insertions, 178 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index f4331264f..8971898b3 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -4,6 +4,7 @@ import types  import typing as t  from contextlib import suppress +import arrow  from discord import Member, NotFound  from discord.ext import commands  from discord.ext.commands import Cog, Context @@ -236,3 +237,35 @@ def mock_in_debug(return_value: t.Any) -> t.Callable:              return await func(*args, **kwargs)          return wrapped      return decorator + + +def ensure_future_timestamp(timestamp_arg: function.Argument) -> t.Callable: +    """ +    Ensure the timestamp argument is in the future. + +    If the condition fails, send a warning to the invoking context. + +    `timestamp_arg` is the keyword name or position index of the parameter of the decorated command +    whose value is the target timestamp. + +    This decorator must go before (below) the `command` decorator. +    """ +    def decorator(func: types.FunctionType) -> types.FunctionType: +        @command_wraps(func) +        async def wrapper(*args, **kwargs) -> t.Any: +            bound_args = function.get_bound_args(func, args, kwargs) +            target = function.get_arg_value(timestamp_arg, bound_args) + +            ctx = function.get_arg_value(1, bound_args) + +            try: +                is_future = target > arrow.utcnow() +            except TypeError: +                is_future = True +            if not is_future: +                await ctx.send(":x: Provided timestamp is in the past.") +                return + +            return await func(*args, **kwargs) +        return wrapper +    return decorator diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index f44b28125..c5a2fdb93 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -268,9 +268,9 @@ class Filtering(Cog):              # Update time when alert sent              await self.name_alerts.set(member.id, arrow.utcnow().timestamp()) -    async def filter_eval(self, result: str, msg: Message) -> bool: +    async def filter_snekbox_output(self, result: str, msg: Message) -> bool:          """ -        Filter the result of an !eval to see if it violates any of our rules, and then respond accordingly. +        Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly.          Also requires the original message, to check whether to filter and for mod logs.          Returns whether a filter was triggered or not. diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index af42ab1b8..18bed5080 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -10,7 +10,7 @@ from bot import constants  from bot.bot import Bot  from bot.constants import Event  from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser -from bot.decorators import respect_role_hierarchy +from bot.decorators import ensure_future_timestamp, respect_role_hierarchy  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.log import get_logger @@ -81,6 +81,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_kick(ctx, user, reason)      @command() +    @ensure_future_timestamp(timestamp_arg=3)      async def ban(          self,          ctx: Context, @@ -97,6 +98,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_ban(ctx, user, reason, expires_at=duration)      @command(aliases=("cban", "purgeban", "pban")) +    @ensure_future_timestamp(timestamp_arg=3)      async def cleanban(          self,          ctx: Context, @@ -161,6 +163,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `voicemute`?")      @command(aliases=("vmute",)) +    @ensure_future_timestamp(timestamp_arg=3)      async def voicemute(          self,          ctx: Context, @@ -180,6 +183,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Temporary infractions      @command(aliases=["mute"]) +    @ensure_future_timestamp(timestamp_arg=3)      async def tempmute(          self, ctx: Context,          user: UnambiguousMemberOrUser, @@ -213,6 +217,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_mute(ctx, user, reason, expires_at=duration)      @command(aliases=("tban",)) +    @ensure_future_timestamp(timestamp_arg=3)      async def tempban(          self,          ctx: Context, @@ -248,6 +253,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await ctx.send(":x: This command is not yet implemented. Maybe you meant to use `tempvoicemute`?")      @command(aliases=("tempvmute", "tvmute")) +    @ensure_future_timestamp(timestamp_arg=3)      async def tempvoicemute(          self,          ctx: Context, @@ -294,6 +300,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Temporary shadow infractions      @command(hidden=True, aliases=["shadowtempban", "stempban", "stban"]) +    @ensure_future_timestamp(timestamp_arg=3)      async def shadow_tempban(          self,          ctx: Context, diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index c12dff928..62d349519 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -9,6 +9,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot  from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser, allowed_strings +from bot.decorators import ensure_future_timestamp  from bot.errors import InvalidInfraction  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction.infractions import Infractions @@ -122,6 +123,7 @@ class ModManagement(commands.Cog):          await self.infraction_edit(ctx, infraction, duration, reason=reason)      @infraction_group.command(name='edit', aliases=('e',)) +    @ensure_future_timestamp(timestamp_arg=3)      async def infraction_edit(          self,          ctx: Context, diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index b91a5edba..c4a7e5081 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -11,6 +11,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot  from bot.converters import Duration, Expiry +from bot.decorators import ensure_future_timestamp  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.log import get_logger @@ -103,6 +104,7 @@ class Superstarify(InfractionScheduler, Cog):              await self.reapply_infraction(infraction, action)      @command(name="superstarify", aliases=("force_nick", "star", "starify", "superstar")) +    @ensure_future_timestamp(timestamp_arg=3)      async def superstarify(          self,          ctx: Context, diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index 32ea0dc6a..796c1f021 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -11,7 +11,7 @@ from deepdiff import DeepDiff  from discord import Colour, Message, Thread  from discord.abc import GuildChannel  from discord.ext.commands import Cog, Context -from discord.utils import escape_markdown +from discord.utils import escape_markdown, format_dt, snowflake_time  from bot.bot import Bot  from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs @@ -573,6 +573,7 @@ class ModLog(Cog, name="ModLog"):                  f"**Author:** {format_user(author)}\n"                  f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"                  f"**Message ID:** `{message.id}`\n" +                f"**Sent at:** {format_dt(message.created_at)}\n"                  f"[Jump to message]({message.jump_url})\n"                  "\n"              ) @@ -581,6 +582,7 @@ class ModLog(Cog, name="ModLog"):                  f"**Author:** {format_user(author)}\n"                  f"**Channel:** #{channel.name} (`{channel.id}`)\n"                  f"**Message ID:** `{message.id}`\n" +                f"**Sent at:** {format_dt(message.created_at)}\n"                  f"[Jump to message]({message.jump_url})\n"                  "\n"              ) @@ -629,6 +631,7 @@ class ModLog(Cog, name="ModLog"):              response = (                  f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"                  f"**Message ID:** `{event.message_id}`\n" +                f"**Sent at:** {format_dt(snowflake_time(event.message_id))}\n"                  "\n"                  "This message was not cached, so the message content cannot be displayed."              ) @@ -636,6 +639,7 @@ class ModLog(Cog, name="ModLog"):              response = (                  f"**Channel:** #{channel.name} (`{channel.id}`)\n"                  f"**Message ID:** `{event.message_id}`\n" +                f"**Sent at:** {format_dt(snowflake_time(event.message_id))}\n"                  "\n"                  "This message was not cached, so the message content cannot be displayed."              ) diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py index fa66b00dd..d6b8f1239 100644 --- a/bot/exts/moderation/voice_gate.py +++ b/bot/exts/moderation/voice_gate.py @@ -10,7 +10,7 @@ from discord.ext.commands import Cog, Context, command  from bot.api import ResponseCodeError  from bot.bot import Bot -from bot.constants import Channels, Event, MODERATION_ROLES, Roles, VoiceGate as GateConf +from bot.constants import Channels, MODERATION_ROLES, Roles, VoiceGate as GateConf  from bot.decorators import has_no_roles, in_whitelist  from bot.exts.moderation.modlog import ModLog  from bot.log import get_logger @@ -191,7 +191,6 @@ class VoiceGate(Cog):                  await ctx.channel.send(ctx.author.mention, embed=embed)              return -        self.mod_log.ignore(Event.member_update, ctx.author.id)          embed = discord.Embed(              title="Voice gate passed",              description="You have been granted permission to use voice channels in Python Discord.", @@ -238,10 +237,6 @@ class VoiceGate(Cog):              log.trace(f"Excluding moderator message {message.id} from deletion in #{message.channel}.")              return -        # Ignore deleted voice verification messages -        if ctx.command is not None and ctx.command.name == "voice_verify": -            self.mod_log.ignore(Event.message_delete, message.id) -          with suppress(discord.NotFound):              await message.delete() diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index cc3a2e1d7..3c1009d2a 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -2,14 +2,14 @@ import asyncio  import contextlib  import datetime  import re -import textwrap  from functools import partial  from signal import Signals +from textwrap import dedent  from typing import Optional, Tuple  from botcore.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX  from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only  from bot.bot import Bot  from bot.constants import Categories, Channels, Roles, URLs @@ -22,17 +22,96 @@ log = get_logger(__name__)  ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}") +# The timeit command should only output the very last line, so all other output should be suppressed. +# This will be used as the setup code along with any setup code provided. +TIMEIT_SETUP_WRAPPER = """ +import atexit +import sys +from collections import deque + +if not hasattr(sys, "_setup_finished"): +    class Writer(deque): +        '''A single-item deque wrapper for sys.stdout that will return the last line when read() is called.''' + +        def __init__(self): +            super().__init__(maxlen=1) + +        def write(self, string): +            '''Append the line to the queue if it is not empty.''' +            if string.strip(): +                self.append(string) + +        def read(self): +            '''This method will be called when print() is called. + +            The queue is emptied as we don't need the output later. +            ''' +            return self.pop() + +        def flush(self): +            '''This method will be called eventually, but we don't need to do anything here.''' +            pass + +    sys.stdout = Writer() + +    def print_last_line(): +        if sys.stdout: # If the deque is empty (i.e. an error happened), calling read() will raise an error +            # Use sys.__stdout__ here because sys.stdout is set to a Writer() instance +            print(sys.stdout.read(), file=sys.__stdout__) + +    atexit.register(print_last_line) # When exiting, print the last line (hopefully it will be the timeit output) +    sys._setup_finished = None +{setup} +""" +  MAX_PASTE_LEN = 10000 -# `!eval` command whitelists and blacklists. -NO_EVAL_CHANNELS = (Channels.python_general,) -NO_EVAL_CATEGORIES = () -EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) +# The Snekbox commands' whitelists and blacklists. +NO_SNEKBOX_CHANNELS = (Channels.python_general,) +NO_SNEKBOX_CATEGORIES = () +SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)  SIGKILL = 9 -REEVAL_EMOJI = '\U0001f501'  # :repeat: -REEVAL_TIMEOUT = 30 +REDO_EMOJI = '\U0001f501'  # :repeat: +REDO_TIMEOUT = 30 + + +class CodeblockConverter(Converter): +    """Attempts to extract code from a codeblock, if provided.""" + +    @classmethod +    async def convert(cls, ctx: Context, code: str) -> list[str]: +        """ +        Extract code from the Markdown, format it, and insert it into the code template. + +        If there is any code block, ignore text outside the code block. +        Use the first code block, but prefer a fenced code block. +        If there are several fenced code blocks, concatenate only the fenced code blocks. + +        Return a list of code blocks if any, otherwise return a list with a single string of code. +        """ +        if match := list(FORMATTED_CODE_REGEX.finditer(code)): +            blocks = [block for block in match if block.group("block")] + +            if len(blocks) > 1: +                codeblocks = [block.group("code") for block in blocks] +                info = "several code blocks" +            else: +                match = match[0] if len(blocks) == 0 else blocks[0] +                code, block, lang, delim = match.group("code", "block", "lang", "delim") +                codeblocks = [dedent(code)] +                if block: +                    info = (f"'{lang}' highlighted" if lang else "plain") + " code block" +                else: +                    info = f"{delim}-enclosed inline code" +        else: +            codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] +            info = "unformatted or badly formatted code" + +        code = "\n".join(codeblocks) +        log.trace(f"Extracted {info} for evaluation:\n{code}") +        return codeblocks  class Snekbox(Cog): @@ -42,15 +121,19 @@ class Snekbox(Cog):          self.bot = bot          self.jobs = {} -    async def post_eval(self, code: str) -> dict: +    async def post_job(self, code: str, *, args: Optional[list[str]] = None) -> dict:          """Send a POST request to the Snekbox API to evaluate code and return the results."""          url = URLs.snekbox_eval_api          data = {"input": code} + +        if args is not None: +            data["args"] = args +          async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:              return await resp.json()      async def upload_output(self, output: str) -> Optional[str]: -        """Upload the eval output to a paste service and return a URL to it if successful.""" +        """Upload the job's output to a paste service and return a URL to it if successful."""          log.trace("Uploading full output to paste service...")          if len(output) > MAX_PASTE_LEN: @@ -59,49 +142,37 @@ class Snekbox(Cog):          return await send_to_paste_service(output, extension="txt")      @staticmethod -    def prepare_input(code: str) -> str: +    def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]:          """ -        Extract code from the Markdown, format it, and insert it into the code template. +        Join the codeblocks into a single string, then return the code and the arguments in a tuple. -        If there is any code block, ignore text outside the code block. -        Use the first code block, but prefer a fenced code block. -        If there are several fenced code blocks, concatenate only the fenced code blocks. +        If there are multiple codeblocks, insert the first one into the wrapped setup code.          """ -        if match := list(FORMATTED_CODE_REGEX.finditer(code)): -            blocks = [block for block in match if block.group("block")] +        args = ["-m", "timeit"] +        setup = "" +        if len(codeblocks) > 1: +            setup = codeblocks.pop(0) -            if len(blocks) > 1: -                code = '\n'.join(block.group("code") for block in blocks) -                info = "several code blocks" -            else: -                match = match[0] if len(blocks) == 0 else blocks[0] -                code, block, lang, delim = match.group("code", "block", "lang", "delim") -                if block: -                    info = (f"'{lang}' highlighted" if lang else "plain") + " code block" -                else: -                    info = f"{delim}-enclosed inline code" -        else: -            code = RAW_CODE_REGEX.fullmatch(code).group("code") -            info = "unformatted or badly formatted code" +        code = "\n".join(codeblocks) -        code = textwrap.dedent(code) -        log.trace(f"Extracted {info} for evaluation:\n{code}") -        return code +        args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + +        return code, args      @staticmethod -    def get_results_message(results: dict) -> Tuple[str, str]: +    def get_results_message(results: dict, job_name: str) -> Tuple[str, str]:          """Return a user-friendly message and error corresponding to the process's return code."""          stdout, returncode = results["stdout"], results["returncode"] -        msg = f"Your eval job has completed with return code {returncode}" +        msg = f"Your {job_name} job has completed with return code {returncode}"          error = ""          if returncode is None: -            msg = "Your eval job has failed" +            msg = f"Your {job_name} job has failed"              error = stdout.strip()          elif returncode == 128 + SIGKILL: -            msg = "Your eval job timed out or ran out of memory" +            msg = f"Your {job_name} job timed out or ran out of memory"          elif returncode == 255: -            msg = "Your eval job has failed" +            msg = f"Your {job_name} job has failed"              error = "A fatal NsJail error occurred"          else:              # Try to append signal's name if one exists @@ -130,8 +201,6 @@ class Snekbox(Cog):          Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters          and upload the full output to a paste service.          """ -        log.trace("Formatting output...") -          output = output.rstrip("\n")          original_output = output  # To be uploaded to a pasting service if needed          paste_link = None @@ -171,19 +240,27 @@ class Snekbox(Cog):          return output, paste_link -    async def send_eval(self, ctx: Context, code: str) -> Message: +    async def send_job( +        self, +        ctx: Context, +        code: str, +        *, +        args: Optional[list[str]] = None, +        job_name: str +    ) -> Message:          """          Evaluate code, format it, and send the output to the corresponding channel.          Return the bot response.          """          async with ctx.typing(): -            results = await self.post_eval(code) -            msg, error = self.get_results_message(results) +            results = await self.post_job(code, args=args) +            msg, error = self.get_results_message(results, job_name)              if error:                  output, paste_link = error, None              else: +                log.trace("Formatting output...")                  output, paste_link = await self.format_output(results["stdout"])              icon = self.get_status_emoji(results) @@ -191,7 +268,7 @@ class Snekbox(Cog):              if paste_link:                  msg = f"{msg}\nFull output: {paste_link}" -            # Collect stats of eval fails + successes +            # Collect stats of job fails + successes              if icon == ":x:":                  self.bot.stats.incr("snekbox.python.fail")              else: @@ -200,7 +277,7 @@ class Snekbox(Cog):              filter_cog = self.bot.get_cog("Filtering")              filter_triggered = False              if filter_cog: -                filter_triggered = await filter_cog.filter_eval(msg, ctx.message) +                filter_triggered = await filter_cog.filter_snekbox_output(msg, ctx.message)              if filter_triggered:                  response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")              else: @@ -208,83 +285,85 @@ class Snekbox(Cog):                  response = await ctx.send(msg, allowed_mentions=allowed_mentions)              scheduling.create_task(wait_for_deletion(response, (ctx.author.id,)), event_loop=self.bot.loop) -            log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") +            log.info(f"{ctx.author}'s {job_name} job had a return code of {results['returncode']}")          return response -    async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: +    async def continue_job( +        self, ctx: Context, response: Message, command: Command +    ) -> tuple[Optional[str], Optional[list[str]]]:          """ -        Check if the eval session should continue. +        Check if the job's session should continue. -        Return the new code to evaluate or None if the eval session should be terminated. +        If the code is to be re-evaluated, return the new code, and the args if the command is the timeit command. +        Otherwise return (None, None) if the job's session should be terminated.          """ -        _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) -        _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) +        _predicate_message_edit = partial(predicate_message_edit, ctx) +        _predicate_emoji_reaction = partial(predicate_emoji_reaction, ctx)          with contextlib.suppress(NotFound):              try:                  _, new_message = await self.bot.wait_for(                      'message_edit', -                    check=_predicate_eval_message_edit, -                    timeout=REEVAL_TIMEOUT +                    check=_predicate_message_edit, +                    timeout=REDO_TIMEOUT                  ) -                await ctx.message.add_reaction(REEVAL_EMOJI) +                await ctx.message.add_reaction(REDO_EMOJI)                  await self.bot.wait_for(                      'reaction_add',                      check=_predicate_emoji_reaction,                      timeout=10                  ) -                code = await self.get_code(new_message) -                await ctx.message.clear_reaction(REEVAL_EMOJI) +                code = await self.get_code(new_message, ctx.command) +                await ctx.message.clear_reaction(REDO_EMOJI)                  with contextlib.suppress(HTTPException):                      await response.delete() +                if code is None: +                    return None, None +              except asyncio.TimeoutError: -                await ctx.message.clear_reaction(REEVAL_EMOJI) -                return None +                await ctx.message.clear_reaction(REDO_EMOJI) +                return None, None + +            codeblocks = await CodeblockConverter.convert(ctx, code) -            return code +            if command is self.timeit_command: +                return self.prepare_timeit_input(codeblocks) +            else: +                return "\n".join(codeblocks), None + +        return None, None -    async def get_code(self, message: Message) -> Optional[str]: +    async def get_code(self, message: Message, command: Command) -> Optional[str]:          """          Return the code from `message` to be evaluated. -        If the message is an invocation of the eval command, return the first argument or None if it +        If the message is an invocation of the command, return the first argument or None if it          doesn't exist. Otherwise, return the full content of the message.          """          log.trace(f"Getting context for message {message.id}.")          new_ctx = await self.bot.get_context(message) -        if new_ctx.command is self.eval_command: -            log.trace(f"Message {message.id} invokes eval command.") +        if new_ctx.command is command: +            log.trace(f"Message {message.id} invokes {command} command.")              split = message.content.split(maxsplit=1)              code = split[1] if len(split) > 1 else None          else: -            log.trace(f"Message {message.id} does not invoke eval command.") +            log.trace(f"Message {message.id} does not invoke {command} command.")              code = message.content          return code -    @command(name="eval", aliases=("e",)) -    @guild_only() -    @redirect_output( -        destination_channel=Channels.bot_commands, -        bypass_roles=EVAL_ROLES, -        categories=NO_EVAL_CATEGORIES, -        channels=NO_EVAL_CHANNELS, -        ping_user=False -    ) -    async def eval_command(self, ctx: Context, *, code: str = None) -> None: -        """ -        Run Python code and get the results. - -        This command supports multiple lines of code, including code wrapped inside a formatted code -        block. Code can be re-evaluated by editing the original message within 10 seconds and -        clicking the reaction that subsequently appears. - -        We've done our best to make this sandboxed, but do let us know if you manage to find an -        issue with it! -        """ +    async def run_job( +        self, +        job_name: str, +        ctx: Context, +        code: str, +        *, +        args: Optional[list[str]] = None, +    ) -> None: +        """Handles checks, stats and re-evaluation of a snekbox job."""          if ctx.author.id in self.jobs:              await ctx.send(                  f"{ctx.author.mention} You've already got a job running - " @@ -292,10 +371,6 @@ class Snekbox(Cog):              )              return -        if not code:  # None or empty string -            await ctx.send_help(ctx.command) -            return -          if Roles.helpers in (role.id for role in ctx.author.roles):              self.bot.stats.incr("snekbox_usages.roles.helpers")          else: @@ -312,26 +387,74 @@ class Snekbox(Cog):          while True:              self.jobs[ctx.author.id] = datetime.datetime.now() -            code = self.prepare_input(code)              try: -                response = await self.send_eval(ctx, code) +                response = await self.send_job(ctx, code, args=args, job_name=job_name)              finally:                  del self.jobs[ctx.author.id] -            code = await self.continue_eval(ctx, response) +            code, args = await self.continue_job(ctx, response, ctx.command)              if not code:                  break              log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") +    @command(name="eval", aliases=("e",)) +    @guild_only() +    @redirect_output( +        destination_channel=Channels.bot_commands, +        bypass_roles=SNEKBOX_ROLES, +        categories=NO_SNEKBOX_CATEGORIES, +        channels=NO_SNEKBOX_CHANNELS, +        ping_user=False +    ) +    async def eval_command(self, ctx: Context, *, code: CodeblockConverter) -> None: +        """ +        Run Python code and get the results. + +        This command supports multiple lines of code, including code wrapped inside a formatted code +        block. Code can be re-evaluated by editing the original message within 10 seconds and +        clicking the reaction that subsequently appears. + +        We've done our best to make this sandboxed, but do let us know if you manage to find an +        issue with it! +        """ +        await self.run_job("eval", ctx, "\n".join(code)) + +    @command(name="timeit", aliases=("ti",)) +    @guild_only() +    @redirect_output( +        destination_channel=Channels.bot_commands, +        bypass_roles=SNEKBOX_ROLES, +        categories=NO_SNEKBOX_CATEGORIES, +        channels=NO_SNEKBOX_CHANNELS, +        ping_user=False +    ) +    async def timeit_command(self, ctx: Context, *, code: CodeblockConverter) -> None: +        """ +        Profile Python Code to find execution time. + +        This command supports multiple lines of code, including code wrapped inside a formatted code +        block. Code can be re-evaluated by editing the original message within 10 seconds and +        clicking the reaction that subsequently appears. + +        If multiple formatted codeblocks are provided, the first one will be the setup code, which will +        not be timed. The remaining codeblocks will be joined together and timed. + +        We've done our best to make this sandboxed, but do let us know if you manage to find an +        issue with it! +        """ +        code, args = self.prepare_timeit_input(code) + +        await self.run_job("timeit", ctx, code=code, args=args) + -def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: +def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:      """Return True if the edited message is the context message and the content was indeed modified."""      return new_msg.id == ctx.message.id and old_msg.content != new_msg.content -def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: -    """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" -    return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI +def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: +    """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" +    return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI  def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 8bdeedd27..f68a20089 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -17,7 +17,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.bot = MockBot()          self.cog = Snekbox(bot=self.bot) -    async def test_post_eval(self): +    async def test_post_job(self):          """Post the eval code to the URLs.snekbox_eval_api endpoint."""          resp = MagicMock()          resp.json = AsyncMock(return_value="return") @@ -26,7 +26,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          context_manager.__aenter__.return_value = resp          self.bot.http_session.post.return_value = context_manager -        self.assertEqual(await self.cog.post_eval("import random"), "return") +        self.assertEqual(await self.cog.post_job("import random"), "return")          self.bot.http_session.post.assert_called_with(              constants.URLs.snekbox_eval_api,              json={"input": "import random"}, @@ -45,7 +45,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          await self.cog.upload_output("Test output.")          mock_paste_util.assert_called_once_with("Test output.", extension="txt") -    def test_prepare_input(self): +    async def test_codeblock_converter(self): +        ctx = MockContext()          cases = (              ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'),              ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), @@ -61,7 +62,24 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          )          for case, expected, testname in cases:              with self.subTest(msg=f'Extract code from {testname}.'): -                self.assertEqual(self.cog.prepare_input(case), expected) +                self.assertEqual( +                    '\n'.join(await snekbox.CodeblockConverter.convert(ctx, case)), expected +                ) + +    def test_prepare_timeit_input(self): +        """Test the prepare_timeit_input codeblock detection.""" +        base_args = ('-m', 'timeit', '-s') +        cases = ( +            (['print("Hello World")'], '', 'single block of code'), +            (['x = 1', 'print(x)'], 'x = 1', 'two blocks of code'), +            (['x = 1', 'print(x)', 'print("Some other code.")'], 'x = 1', 'three blocks of code') +        ) + +        for case, setup_code, testname in cases: +            setup = snekbox.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) +            expected = ('\n'.join(case[1:] if setup_code else case), [*base_args, setup]) +            with self.subTest(msg=f'Test with {testname} and expected return {expected}'): +                self.assertEqual(self.cog.prepare_timeit_input(case), expected)      def test_get_results_message(self):          """Return error and message according to the eval result.""" @@ -72,13 +90,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          )          for stdout, returncode, expected in cases:              with self.subTest(stdout=stdout, returncode=returncode, expected=expected): -                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) +                actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}, 'eval')                  self.assertEqual(actual, expected)      @patch('bot.exts.utils.snekbox.Signals', side_effect=ValueError)      def test_get_results_message_invalid_signal(self, mock_signals: Mock):          self.assertEqual( -            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),              ('Your eval job has completed with return code 127', '')          ) @@ -86,7 +104,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      def test_get_results_message_valid_signal(self, mock_signals: Mock):          mock_signals.return_value.name = 'SIGTEST'          self.assertEqual( -            self.cog.get_results_message({'stdout': '', 'returncode': 127}), +            self.cog.get_results_message({'stdout': '', 'returncode': 127}, 'eval'),              ('Your eval job has completed with return code 127 (SIGTEST)', '')          ) @@ -156,28 +174,29 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Test the eval command procedure."""          ctx = MockContext()          response = MockMessage() -        self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') -        self.cog.send_eval = AsyncMock(return_value=response) -        self.cog.continue_eval = AsyncMock(return_value=None) +        ctx.command = MagicMock() -        await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') -        self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') -        self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') -        self.cog.continue_eval.assert_called_once_with(ctx, response) +        self.cog.send_job = AsyncMock(return_value=response) +        self.cog.continue_job = AsyncMock(return_value=(None, None)) + +        await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) +        self.cog.send_job.assert_called_once_with(ctx, 'MyAwesomeCode', args=None, job_name='eval') +        self.cog.continue_job.assert_called_once_with(ctx, response, ctx.command)      async def test_eval_command_evaluate_twice(self):          """Test the eval and re-eval command procedure."""          ctx = MockContext()          response = MockMessage() -        self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') -        self.cog.send_eval = AsyncMock(return_value=response) -        self.cog.continue_eval = AsyncMock() -        self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) - -        await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') -        self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) -        self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') -        self.cog.continue_eval.assert_called_with(ctx, response) +        ctx.command = MagicMock() +        self.cog.send_job = AsyncMock(return_value=response) +        self.cog.continue_job = AsyncMock() +        self.cog.continue_job.side_effect = (('MyAwesomeFormattedCode', None), (None, None)) + +        await self.cog.eval_command(self.cog, ctx=ctx, code=['MyAwesomeCode']) +        self.cog.send_job.assert_called_with( +            ctx, 'MyAwesomeFormattedCode', args=None, job_name='eval' +        ) +        self.cog.continue_job.assert_called_with(ctx, response, ctx.command)      async def test_eval_command_reject_two_eval_at_the_same_time(self):          """Test if the eval command rejects an eval if the author already have a running eval.""" @@ -191,29 +210,23 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"          ) -    async def test_eval_command_call_help(self): -        """Test if the eval command call the help command if no code is provided.""" -        ctx = MockContext(command="sentinel") -        await self.cog.eval_command(self.cog, ctx=ctx, code='') -        ctx.send_help.assert_called_once_with(ctx.command) - -    async def test_send_eval(self): -        """Test the send_eval function.""" +    async def test_send_job(self): +        """Test the send_job function."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author = MockUser(mention='@LemonLemonishBeard#0042') -        self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) +        self.cog.post_job = AsyncMock(return_value={'stdout': '', 'returncode': 0})          self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))          self.cog.get_status_emoji = MagicMock(return_value=':yay!:')          self.cog.format_output = AsyncMock(return_value=('[No output]', None))          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_eval = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')          ctx.send.assert_called_once()          self.assertEqual( @@ -224,28 +237,28 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])          self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) -        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)          self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) -        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}, 'eval')          self.cog.format_output.assert_called_once_with('') -    async def test_send_eval_with_paste_link(self): -        """Test the send_eval function with a too long output that generate a paste link.""" +    async def test_send_job_with_paste_link(self): +        """Test the send_job function with a too long output that generate a paste link."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.post_job = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})          self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))          self.cog.get_status_emoji = MagicMock(return_value=':yay!:')          self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_eval = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')          ctx.send.assert_called_once()          self.assertEqual( @@ -254,27 +267,27 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'          ) -        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)          self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) -        self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}, 'eval')          self.cog.format_output.assert_called_once_with('Way too long beard') -    async def test_send_eval_with_non_zero_eval(self): -        """Test the send_eval function with a code returning a non-zero code.""" +    async def test_send_job_with_non_zero_eval(self): +        """Test the send_job function with a code returning a non-zero code."""          ctx = MockContext()          ctx.message = MockMessage()          ctx.send = AsyncMock()          ctx.author.mention = '@LemonLemonishBeard#0042' -        self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) +        self.cog.post_job = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})          self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))          self.cog.get_status_emoji = MagicMock(return_value=':nope!:')          self.cog.format_output = AsyncMock()  # This function isn't called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_eval = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)          self.bot.get_cog.return_value = mocked_filter_cog -        await self.cog.send_eval(ctx, 'MyAwesomeCode') +        await self.cog.send_job(ctx, 'MyAwesomeCode', job_name='eval')          ctx.send.assert_called_once()          self.assertEqual( @@ -282,14 +295,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'          ) -        self.cog.post_eval.assert_called_once_with('MyAwesomeCode') +        self.cog.post_job.assert_called_once_with('MyAwesomeCode', args=None)          self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) -        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) +        self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}, 'eval')          self.cog.format_output.assert_not_called()      @patch("bot.exts.utils.snekbox.partial") -    async def test_continue_eval_does_continue(self, partial_mock): -        """Test that the continue_eval function does continue if required conditions are met.""" +    async def test_continue_job_does_continue(self, partial_mock): +        """Test that the continue_job function does continue if required conditions are met."""          ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))          response = MockMessage(delete=AsyncMock())          new_msg = MockMessage() @@ -297,30 +310,30 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          expected = "NewCode"          self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) -        actual = await self.cog.continue_eval(ctx, response) -        self.cog.get_code.assert_awaited_once_with(new_msg) -        self.assertEqual(actual, expected) +        actual = await self.cog.continue_job(ctx, response, self.cog.eval_command) +        self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) +        self.assertEqual(actual, (expected, None))          self.bot.wait_for.assert_has_awaits(              (                  call(                      'message_edit', -                    check=partial_mock(snekbox.predicate_eval_message_edit, ctx), -                    timeout=snekbox.REEVAL_TIMEOUT, +                    check=partial_mock(snekbox.predicate_message_edit, ctx), +                    timeout=snekbox.REDO_TIMEOUT,                  ), -                call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) +                call('reaction_add', check=partial_mock(snekbox.predicate_emoji_reaction, ctx), timeout=10)              )          ) -        ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) -        ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) +        ctx.message.add_reaction.assert_called_once_with(snekbox.REDO_EMOJI) +        ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)          response.delete.assert_called_once() -    async def test_continue_eval_does_not_continue(self): +    async def test_continue_job_does_not_continue(self):          ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))          self.bot.wait_for.side_effect = asyncio.TimeoutError -        actual = await self.cog.continue_eval(ctx, MockMessage()) -        self.assertEqual(actual, None) -        ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) +        actual = await self.cog.continue_job(ctx, MockMessage(), self.cog.eval_command) +        self.assertEqual(actual, (None, None)) +        ctx.message.clear_reaction.assert_called_once_with(snekbox.REDO_EMOJI)      async def test_get_code(self):          """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" @@ -343,13 +356,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):                  self.bot.get_context.return_value = MockContext(command=command)                  message = MockMessage(content=content) -                actual_code = await self.cog.get_code(message) +                actual_code = await self.cog.get_code(message, self.cog.eval_command)                  self.bot.get_context.assert_awaited_once_with(message)                  self.assertEqual(actual_code, expected_code) -    def test_predicate_eval_message_edit(self): -        """Test the predicate_eval_message_edit function.""" +    def test_predicate_message_edit(self): +        """Test the predicate_message_edit function."""          msg0 = MockMessage(id=1, content='abc')          msg1 = MockMessage(id=2, content='abcdef')          msg2 = MockMessage(id=1, content='abcdef') @@ -362,18 +375,18 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          for ctx_msg, new_msg, expected, testname in cases:              with self.subTest(msg=f'Messages with {testname} return {expected}'):                  ctx = MockContext(message=ctx_msg) -                actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) +                actual = snekbox.predicate_message_edit(ctx, ctx_msg, new_msg)                  self.assertEqual(actual, expected) -    def test_predicate_eval_emoji_reaction(self): -        """Test the predicate_eval_emoji_reaction function.""" +    def test_predicate_emoji_reaction(self): +        """Test the predicate_emoji_reaction function."""          valid_reaction = MockReaction(message=MockMessage(id=1)) -        valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI +        valid_reaction.__str__.return_value = snekbox.REDO_EMOJI          valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))          valid_user = MockUser(id=2)          invalid_reaction_id = MockReaction(message=MockMessage(id=42)) -        invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI +        invalid_reaction_id.__str__.return_value = snekbox.REDO_EMOJI          invalid_user_id = MockUser(id=42)          invalid_reaction_str = MockReaction(message=MockMessage(id=1))          invalid_reaction_str.__str__.return_value = ':longbeard:' @@ -386,7 +399,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          )          for reaction, user, expected, testname in cases:              with self.subTest(msg=f'Test with {testname} and expected return {expected}'): -                actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) +                actual = snekbox.predicate_emoji_reaction(valid_ctx, reaction, user)                  self.assertEqual(actual, expected) | 
