diff options
-rw-r--r-- | bot/exts/utils/snekbox.py | 70 | ||||
-rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 6 |
2 files changed, 56 insertions, 20 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index bd521a4ee..0d8da5e56 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -2,9 +2,9 @@ import asyncio import contextlib import datetime import re -import textwrap from functools import partial from signal import Signals +from textwrap import dedent from typing import Awaitable, Callable, Optional, Tuple from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User @@ -36,13 +36,35 @@ RAW_CODE_REGEX = re.compile( re.DOTALL # "." also matches newlines ) -TIMEIT_EVAL_WRAPPER = """ -from contextlib import redirect_stdout -from io import StringIO +TIMEIT_SETUP_WRAPPER = """ +import atexit +import sys +from collections import deque -with redirect_stdout(StringIO()): - del redirect_stdout, StringIO -{code} +if not hasattr(sys, "_setup_finished"): + class Writer(deque): + def __init__(self): + super().__init__(maxlen=1) + + def write(self, string): + if string.strip(): + self.append(string) + + def read(self): + return self.pop() + + def flush(self): + pass + + sys.stdout = Writer() + + def print_last_line(): + if sys.stdout: + print(sys.stdout.read(), file=sys.__stdout__) + + atexit.register(print_last_line) + sys._setup_finished = None +{setup} """ TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop") @@ -90,34 +112,37 @@ class Snekbox(Cog): return await send_to_paste_service(output, extension="txt") @staticmethod - def prepare_input(code: str) -> str: + def prepare_input(code: str) -> list[str]: """ Extract code from the Markdown, format it, and insert it into the code template. If there is any code block, ignore text outside the code block. Use the first code block, but prefer a fenced code block. If there are several fenced code blocks, concatenate only the fenced code blocks. + + Retrun a list of code blocks if any, otherwise return a list with a single string of code. """ if match := list(FORMATTED_CODE_REGEX.finditer(code)): blocks = [block for block in match if block.group("block")] if len(blocks) > 1: - code = '\n'.join(block.group("code") for block in blocks) + codeblocks = [block.group("code") for block in blocks] info = "several code blocks" else: match = match[0] if len(blocks) == 0 else blocks[0] code, block, lang, delim = match.group("code", "block", "lang", "delim") + codeblocks = [dedent(code)] if block: info = (f"'{lang}' highlighted" if lang else "plain") + " code block" else: info = f"{delim}-enclosed inline code" else: - code = RAW_CODE_REGEX.fullmatch(code).group("code") + codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))] info = "unformatted or badly formatted code" - code = textwrap.dedent(code) + code = "\n".join(codeblocks) log.trace(f"Extracted {info} for evaluation:\n{code}") - return code + return codeblocks @staticmethod def get_results_message(results: dict) -> Tuple[str, str]: @@ -248,7 +273,7 @@ class Snekbox(Cog): log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") return response - async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]: """ Check if the eval session should continue. @@ -380,7 +405,7 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = self.prepare_input(code) + code = "\n".join(self.prepare_input(code)) await self.run_eval(ctx, code, format_func=self.format_output) @command(name="timeit", aliases=("ti",)) @@ -400,13 +425,24 @@ class Snekbox(Cog): block. Code can be re-evaluated by editing the original message within 10 seconds and clicking the reaction that subsequently appears. + If multiple formatted codeblocks are provided, the first one will be the setup code, which will + not be timed. The remaining codeblocks will be joined together and timed. + We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - code = self.prepare_input(code) + args = ["-m", "timeit"] + setup = "" + codeblocks = self.prepare_input(code) + + if len(codeblocks) > 1: + setup = codeblocks.pop(0) + + code = "\n".join(codeblocks) + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + await self.run_eval( - ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")), - format_func=self.format_timeit_output, args=["-m", "timeit"] + ctx, code=code, format_func=self.format_timeit_output, args=args ) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index cbffaa6b0..ebab71e71 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -61,7 +61,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): ) for case, expected, testname in cases: with self.subTest(msg=f'Extract code from {testname}.'): - self.assertEqual(self.cog.prepare_input(case), expected) + self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected) def test_get_results_message(self): """Return error and message according to the eval result.""" @@ -156,7 +156,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): """Test the eval command procedure.""" ctx = MockContext() response = MockMessage() - self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode']) self.cog.send_eval = AsyncMock(return_value=response) self.cog.continue_eval = AsyncMock(return_value=None) @@ -297,7 +297,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): actual = await self.cog.continue_eval(ctx, response) self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command) - self.assertEqual(actual, expected) + self.assertEqual(actual, [expected]) self.bot.wait_for.assert_has_awaits( ( call( |