diff options
-rw-r--r-- | bot/exts/utils/snekbox.py | 142 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 32 |
2 files changed, 122 insertions, 52 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index ef24cbd77..bd521a4ee 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -5,10 +5,10 @@ import re import textwrap from functools import partial from signal import Signals -from typing import Optional, Tuple +from typing import Awaitable, Callable, Optional, Tuple from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User -from discord.ext.commands import Cog, Context, command, guild_only +from discord.ext.commands import Cog, Command, Context, command, guild_only from bot.bot import Bot from bot.constants import Categories, Channels, Roles, URLs @@ -36,6 +36,17 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) +TIMEIT_EVAL_WRAPPER = """ +from contextlib import redirect_stdout +from io import StringIO + +with redirect_stdout(StringIO()): + del redirect_stdout, StringIO +{code} +""" + +TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop") + MAX_PASTE_LEN = 10000 # `!eval` command whitelists and blacklists. @@ -48,6 +59,8 @@ SIGKILL = 9 REEVAL_EMOJI = '\U0001f501' # :repeat: REEVAL_TIMEOUT = 30 +FormatFunc = Callable[[str], Awaitable[tuple[str, Optional[str]]]] + class Snekbox(Cog): """Safe evaluation of Python code using Snekbox.""" @@ -56,10 +69,14 @@ class Snekbox(Cog): self.bot = bot self.jobs = {} - async def post_eval(self, code: str) -> dict: + async def post_eval(self, code: str, *, args: Optional[list[str]] = None) -> dict: """Send a POST request to the Snekbox API to evaluate code and return the results.""" url = URLs.snekbox_eval_api data = {"input": code} + + if args is not None: + data["args"] = args + async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp: return await resp.json() @@ -144,8 +161,6 @@ class Snekbox(Cog): Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters and upload the full output to a paste service. """ - log.trace("Formatting output...") - output = output.rstrip("\n") original_output = output # To be uploaded to a pasting service if needed paste_link = None @@ -185,20 +200,28 @@ class Snekbox(Cog): return output, paste_link - async def send_eval(self, ctx: Context, code: str) -> Message: + async def send_eval( + self, + ctx: Context, + code: str, + *, + args: Optional[list[str]] = None, + format_func: FormatFunc + ) -> Message: """ Evaluate code, format it, and send the output to the corresponding channel. Return the bot response. """ async with ctx.typing(): - results = await self.post_eval(code) + results = await self.post_eval(code, args=args) msg, error = self.get_results_message(results) if error: output, paste_link = error, None else: - output, paste_link = await self.format_output(results["stdout"]) + log.trace("Formatting output...") + output, paste_link = await format_func(results["stdout"]) icon = self.get_status_emoji(results) msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```" @@ -248,7 +271,7 @@ class Snekbox(Cog): timeout=10 ) - code = await self.get_code(new_message) + code = await self.get_code(new_message, ctx.command) await ctx.message.clear_reaction(REEVAL_EMOJI) with contextlib.suppress(HTTPException): await response.delete() @@ -257,9 +280,9 @@ class Snekbox(Cog): await ctx.message.clear_reaction(REEVAL_EMOJI) return None - return code + return self.prepare_input(code) - async def get_code(self, message: Message) -> Optional[str]: + async def get_code(self, message: Message, command: Command) -> Optional[str]: """ Return the code from `message` to be evaluated. @@ -269,7 +292,7 @@ class Snekbox(Cog): log.trace(f"Getting context for message {message.id}.") new_ctx = await self.bot.get_context(message) - if new_ctx.command is self.eval_command: + if new_ctx.command is command: log.trace(f"Message {message.id} invokes eval command.") split = message.content.split(maxsplit=1) code = split[1] if len(split) > 1 else None @@ -279,25 +302,18 @@ class Snekbox(Cog): return code - @command(name="eval", aliases=("e",)) - @guild_only() - @redirect_output( - destination_channel=Channels.bot_commands, - bypass_roles=EVAL_ROLES, - categories=NO_EVAL_CATEGORIES, - channels=NO_EVAL_CHANNELS, - ping_user=False - ) - async def eval_command(self, ctx: Context, *, code: str = None) -> None: + async def run_eval( + self, + ctx: Context, + code: str, + format_func: FormatFunc, + *, + args: Optional[list[str]] = None, + ) -> None: """ - Run Python code and get the results. + Handles checks, stats and re-evaluation of an eval. - This command supports multiple lines of code, including code wrapped inside a formatted code - block. Code can be re-evaluated by editing the original message within 10 seconds and - clicking the reaction that subsequently appears. - - We've done our best to make this sandboxed, but do let us know if you manage to find an - issue with it! + `format_func` is an async callable that takes a string (the output) and formats it to show to the user. """ if ctx.author.id in self.jobs: await ctx.send( @@ -306,10 +322,6 @@ class Snekbox(Cog): ) return - if not code: # None or empty string - await ctx.send_help(ctx.command) - return - if Roles.helpers in (role.id for role in ctx.author.roles): self.bot.stats.incr("snekbox_usages.roles.helpers") else: @@ -326,9 +338,8 @@ class Snekbox(Cog): while True: self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) try: - response = await self.send_eval(ctx, code) + response = await self.send_eval(ctx, code, args=args, format_func=format_func) finally: del self.jobs[ctx.author.id] @@ -337,6 +348,67 @@ class Snekbox(Cog): break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") + async def format_timeit_output(self, output: str) -> tuple[str, str]: + """ + Parses the time from the end of the output given by timeit. + + If an error happened, then it won't contain the time and instead proceed with regular formatting. + """ + split_output = output.rstrip("\n").rsplit("\n", 1) + if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]): + return split_output[1], None + + return await self.format_output(output) + + @command(name="eval", aliases=("e",)) + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=EVAL_ROLES, + categories=NO_EVAL_CATEGORIES, + channels=NO_EVAL_CHANNELS, + ping_user=False + ) + async def eval_command(self, ctx: Context, *, code: str) -> None: + """ + Run Python code and get the results. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code = self.prepare_input(code) + await self.run_eval(ctx, code, format_func=self.format_output) + + @command(name="timeit", aliases=("ti",)) + @guild_only() + @redirect_output( + destination_channel=Channels.bot_commands, + bypass_roles=EVAL_ROLES, + categories=NO_EVAL_CATEGORIES, + channels=NO_EVAL_CHANNELS, + ping_user=False + ) + async def timeit_command(self, ctx: Context, *, code: str) -> str: + """ + Profile Python Code to find execution time. + + This command supports multiple lines of code, including code wrapped inside a formatted code + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an + issue with it! + """ + code = self.prepare_input(code) + await self.run_eval( + ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")), + format_func=self.format_timeit_output, args=["-m", "timeit"] + ) + def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: """Return True if the edited message is the context message and the content was indeed modified.""" diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 8bdeedd27..cbffaa6b0 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -162,7 +162,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') - self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') + self.cog.send_eval.assert_called_once_with( + ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ) self.cog.continue_eval.assert_called_once_with(ctx, response) async def test_eval_command_evaluate_twice(self): @@ -172,11 +174,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock() - self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + self.cog.continue_eval.side_effect = ('MyAwesomeFormattedCode', None) await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode') self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) - self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') + self.cog.send_eval.assert_called_with( + ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output + ) self.cog.continue_eval.assert_called_with(ctx, response) async def test_eval_command_reject_two_eval_at_the_same_time(self): @@ -191,12 +195,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" ) - async def test_eval_command_call_help(self): - """Test if the eval command call the help command if no code is provided.""" - ctx = MockContext(command="sentinel") - await self.cog.eval_command(self.cog, ctx=ctx, code='') - ctx.send_help.assert_called_once_with(ctx.command) - async def test_send_eval(self): """Test the send_eval function.""" ctx = MockContext() @@ -213,7 +211,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once() self.assertEqual( @@ -224,7 +222,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict()) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) self.cog.format_output.assert_called_once_with('') @@ -245,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once() self.assertEqual( @@ -254,7 +252,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) self.cog.format_output.assert_called_once_with('Way too long beard') @@ -274,7 +272,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): mocked_filter_cog.filter_eval = AsyncMock(return_value=False) self.bot.get_cog.return_value = mocked_filter_cog - await self.cog.send_eval(ctx, 'MyAwesomeCode') + await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output) ctx.send.assert_called_once() self.assertEqual( @@ -282,7 +280,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```' ) - self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None) self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) self.cog.format_output.assert_not_called() @@ -298,7 +296,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected) actual = await self.cog.continue_eval(ctx, response) - self.cog.get_code.assert_awaited_once_with(new_msg) + self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) self.assertEqual(actual, expected) self.bot.wait_for.assert_has_awaits( ( @@ -343,7 +341,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.bot.get_context.return_value = MockContext(command=command) message = MockMessage(content=content) - actual_code = await self.cog.get_code(message) + actual_code = await self.cog.get_code(message, self.cog.eval_command) self.bot.get_context.assert_awaited_once_with(message) self.assertEqual(actual_code, expected_code) |