diff options
-rw-r--r-- | bot/exts/utils/snekbox.py | 142 |
1 files changed, 107 insertions, 35 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index b1f1ba6a8..615956637 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -6,10 +6,10 @@ import re import textwrap from functools import partial from signal import Signals -from typing import Optional, Tuple +from typing import Awaitable, Callable, Optional, Tuple from discord import HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs @@ -36,6 +36,17 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) +TIMEIT_EVAL_WRAPPER = """ +from contextlib import redirect_stdout +from io import StringIO + +with redirect_stdout(StringIO()): + del redirect_stdout, StringIO +{code} +""" + +TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop") + MAX_PASTE_LEN = 10000 # `!eval` command whitelists and blacklists. @@ -48,6 +59,8 @@ SIGKILL = 9 REEVAL_EMOJI = '\U0001f501' # :repeat: REEVAL_TIMEOUT = 30 +FormatFunc = Callable[[str], Awaitable[tuple[str, Optional[str]]]] + class Snekbox(Cog): """Safe evaluation of Python code using Snekbox.""" @@ -56,10 +69,14 @@ class Snekbox(Cog): self.bot = bot self.jobs = {} - async def post_eval(self, code: str) -> dict: + async def post_eval(self, code: str, *, args: Optional[list[str]]) -> dict: """Send a POST request to the Snekbox API to evaluate code and return the results.""" url = URLs.snekbox_eval_api data = {"input": code} + + if args is not None: + data["args"] = args + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: return await resp.json() @@ -144,8 +161,6 @@ class Snekbox(Cog): Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters and upload the full output to a paste service. """ - log.trace("Formatting output...") - output = output.rstrip("\n") original_output = output # To be uploaded to a pasting service if needed paste_link = None @@ -185,20 +200,28 @@ class Snekbox(Cog): return output, paste_link - async def send_eval(self, ctx: Context, code: str) -> Message: + async def send_eval( + self, + ctx: Context, + code: str, + *, + args: Optional[list[str]], + format_func: FormatFunc + ) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. Return the bot response. """ async with ctx.typing(): - results = await self.post_eval(code) + results = await self.post_eval(code, args=args) msg, error = self.get_results_message(results) if error: output, paste_link = error, None else: - output, paste_link = await self.format_output(results["stdout"]) + log.trace("Formatting output...") + output, paste_link = await format_func(results["stdout"]) icon = self.get_status_emoji(results) msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" @@ -247,7 +270,7 @@ class Snekbox(Cog): timeout=10 ) - code = await self.get_code(new_message) + code = await self.get_code(new_message, ctx.command) await ctx.message.clear_reaction(REEVAL_EMOJI) with contextlib.suppress(HTTPException): await response.delete() @@ -256,9 +279,9 @@ class Snekbox(Cog): await ctx.message.clear_reaction(REEVAL_EMOJI) return None - return code + return self.prepare_input(code) - async def get_code(self, message: Message) -> Optional[str]: + async def get_code(self, message: Message, command: Command) -> Optional[str]: """ Return the code from `message` to be evaluated. @@ -268,7 +291,7 @@ class Snekbox(Cog): log.trace(f"Getting context for message {message.id}.") new_ctx = await self.bot.get_context(message) - if new_ctx.command is self.eval_command: + if new_ctx.command is command: log.trace(f"Message {message.id} invokes eval command.") split = message.content.split(maxsplit=1) code = split[1] if len(split) > 1 else None @@ -278,25 +301,18 @@ class Snekbox(Cog): return code - @command(name="eval", aliases=("e",)) - @guild_only() - @redirect_output( - destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, - ping_user=False - ) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: + async def run_eval( + self, + ctx: Context, + code: str, + format_func: FormatFunc, + *, + args: Optional[list[str]] = None, + ) -> None: """ - Run Python code and get the results. + Handles checks, stats and re-evaluation of an eval. - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! + `format_func` is an async callable that takes a string (the output) and formats it to show to the user. """ if ctx.author.id in self.jobs: await ctx.send( @@ -305,10 +321,6 @@ class Snekbox(Cog): ) return - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - if Roles.helpers in (role.id for role in ctx.author.roles): self.bot.stats.incr("snekbox_usages.roles.helpers") else: @@ -325,9 +337,8 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) try: - response = await self.send_eval(ctx, code) + response = await self.send_eval(ctx, code, args=args, format_func=format_func) finally: del self.jobs[ctx.author.id] @@ -336,6 +347,67 @@ class Snekbox(Cog): break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + async def format_timeit_output(self, output: str) -> tuple[str, str]: + """ + Parses the time from the end of the output given by timeit. + + If an error happened, then it won't contain the time and instead proceed with regular formatting. + """ + split_output = output.rstrip("\n").rsplit("\n", 1) + if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]): + return split_output[1], None + + return await self.format_output(output) + + @command(name="eval", aliases=("e",)) + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=EVAL_ROLES, + categories=NO_EVAL_CATEGORIES, + channels=NO_EVAL_CHANNELS, + ping_user=False + ) + async def eval_command(self, ctx: Context, *, code: str) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code = self.prepare_input(code) + await self.run_eval(ctx, code, format_func=self.format_output) + + @command(name="timeit", aliases=("ti",)) + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=EVAL_ROLES, + categories=NO_EVAL_CATEGORIES, + channels=NO_EVAL_CHANNELS, + ping_user=False + ) + async def timeit_command(self, ctx: Context, *, code: str) -> str: + """ + Profile Python Code to find execution time. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code = self.prepare_input(code) + await self.run_eval( + ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")), + format_func=self.format_timeit_output, args=["-m", "timeit"] + ) + def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: """Return True if the edited message is the context message and the content was indeed modified.""" |