aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ToxicKidz <[email protected]>2022-01-13 21:26:58 -0500
committerGravatar ToxicKidz <[email protected]>2022-01-13 21:26:58 -0500
commitb7e49a5fb1adb541db2cf5632a460a37ddda6d0a (patch)
tree17f0b74fb63e07398bfc3c69414fddeef8283958
parentchore: Fix merge conflicts (diff)
chore: Suppress output in the setup code, not the code that gets timed.
If multiple formatted codeblocks are passed to the command, the first one will be used as the setup code that does not get timed.
-rw-r--r--bot/exts/utils/snekbox.py70
-rw-r--r--tests/bot/exts/utils/test_snekbox.py6
2 files changed, 56 insertions, 20 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index bd521a4ee..0d8da5e56 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -2,9 +2,9 @@ import asyncio
import contextlib
import datetime
import re
-import textwrap
from functools import partial
from signal import Signals
+from textwrap import dedent
from typing import Awaitable, Callable, Optional, Tuple
from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User
@@ -36,13 +36,35 @@ RAW_CODE_REGEX = re.compile(
re.DOTALL # "." also matches newlines
)
-TIMEIT_EVAL_WRAPPER = """
-from contextlib import redirect_stdout
-from io import StringIO
+TIMEIT_SETUP_WRAPPER = """
+import atexit
+import sys
+from collections import deque
-with redirect_stdout(StringIO()):
- del redirect_stdout, StringIO
-{code}
+if not hasattr(sys, "_setup_finished"):
+ class Writer(deque):
+ def __init__(self):
+ super().__init__(maxlen=1)
+
+ def write(self, string):
+ if string.strip():
+ self.append(string)
+
+ def read(self):
+ return self.pop()
+
+ def flush(self):
+ pass
+
+ sys.stdout = Writer()
+
+ def print_last_line():
+ if sys.stdout:
+ print(sys.stdout.read(), file=sys.__stdout__)
+
+ atexit.register(print_last_line)
+ sys._setup_finished = None
+{setup}
"""
TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop")
@@ -90,34 +112,37 @@ class Snekbox(Cog):
return await send_to_paste_service(output, extension="txt")
@staticmethod
- def prepare_input(code: str) -> str:
+ def prepare_input(code: str) -> list[str]:
"""
Extract code from the Markdown, format it, and insert it into the code template.
If there is any code block, ignore text outside the code block.
Use the first code block, but prefer a fenced code block.
If there are several fenced code blocks, concatenate only the fenced code blocks.
+
+ Retrun a list of code blocks if any, otherwise return a list with a single string of code.
"""
if match := list(FORMATTED_CODE_REGEX.finditer(code)):
blocks = [block for block in match if block.group("block")]
if len(blocks) > 1:
- code = '\n'.join(block.group("code") for block in blocks)
+ codeblocks = [block.group("code") for block in blocks]
info = "several code blocks"
else:
match = match[0] if len(blocks) == 0 else blocks[0]
code, block, lang, delim = match.group("code", "block", "lang", "delim")
+ codeblocks = [dedent(code)]
if block:
info = (f"'{lang}' highlighted" if lang else "plain") + " code block"
else:
info = f"{delim}-enclosed inline code"
else:
- code = RAW_CODE_REGEX.fullmatch(code).group("code")
+ codeblocks = [dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))]
info = "unformatted or badly formatted code"
- code = textwrap.dedent(code)
+ code = "\n".join(codeblocks)
log.trace(f"Extracted {info} for evaluation:\n{code}")
- return code
+ return codeblocks
@staticmethod
def get_results_message(results: dict) -> Tuple[str, str]:
@@ -248,7 +273,7 @@ class Snekbox(Cog):
log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
return response
- async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]:
+ async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]:
"""
Check if the eval session should continue.
@@ -380,7 +405,7 @@ class Snekbox(Cog):
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
- code = self.prepare_input(code)
+ code = "\n".join(self.prepare_input(code))
await self.run_eval(ctx, code, format_func=self.format_output)
@command(name="timeit", aliases=("ti",))
@@ -400,13 +425,24 @@ class Snekbox(Cog):
block. Code can be re-evaluated by editing the original message within 10 seconds and
clicking the reaction that subsequently appears.
+ If multiple formatted codeblocks are provided, the first one will be the setup code, which will
+ not be timed. The remaining codeblocks will be joined together and timed.
+
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
- code = self.prepare_input(code)
+ args = ["-m", "timeit"]
+ setup = ""
+ codeblocks = self.prepare_input(code)
+
+ if len(codeblocks) > 1:
+ setup = codeblocks.pop(0)
+
+ code = "\n".join(codeblocks)
+ args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)])
+
await self.run_eval(
- ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")),
- format_func=self.format_timeit_output, args=["-m", "timeit"]
+ ctx, code=code, format_func=self.format_timeit_output, args=args
)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index cbffaa6b0..ebab71e71 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -61,7 +61,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
for case, expected, testname in cases:
with self.subTest(msg=f'Extract code from {testname}.'):
- self.assertEqual(self.cog.prepare_input(case), expected)
+ self.assertEqual('\n'.join(self.cog.prepare_input(case)), expected)
def test_get_results_message(self):
"""Return error and message according to the eval result."""
@@ -156,7 +156,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Test the eval command procedure."""
ctx = MockContext()
response = MockMessage()
- self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.prepare_input = MagicMock(return_value=['MyAwesomeFormattedCode'])
self.cog.send_eval = AsyncMock(return_value=response)
self.cog.continue_eval = AsyncMock(return_value=None)
@@ -297,7 +297,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
actual = await self.cog.continue_eval(ctx, response)
self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command)
- self.assertEqual(actual, expected)
+ self.assertEqual(actual, [expected])
self.bot.wait_for.assert_has_awaits(
(
call(