aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/snekbox.py142
-rw-r--r--tests/bot/exts/utils/test_snekbox.py32
2 files changed, 122 insertions, 52 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index ef24cbd77..bd521a4ee 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -5,10 +5,10 @@ import re
import textwrap
from functools import partial
from signal import Signals
-from typing import Optional, Tuple
+from typing import Awaitable, Callable, Optional, Tuple
from discord import AllowedMentions, HTTPException, Message, NotFound, Reaction, User
-from discord.ext.commands import Cog, Context, command, guild_only
+from discord.ext.commands import Cog, Command, Context, command, guild_only
from bot.bot import Bot
from bot.constants import Categories, Channels, Roles, URLs
@@ -36,6 +36,17 @@ RAW_CODE_REGEX = re.compile(
re.DOTALL # "." also matches newlines
)
+TIMEIT_EVAL_WRAPPER = """
+from contextlib import redirect_stdout
+from io import StringIO
+
+with redirect_stdout(StringIO()):
+ del redirect_stdout, StringIO
+{code}
+"""
+
+TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop")
+
MAX_PASTE_LEN = 10000
# `!eval` command whitelists and blacklists.
@@ -48,6 +59,8 @@ SIGKILL = 9
REEVAL_EMOJI = '\U0001f501' # :repeat:
REEVAL_TIMEOUT = 30
+FormatFunc = Callable[[str], Awaitable[tuple[str, Optional[str]]]]
+
class Snekbox(Cog):
"""Safe evaluation of Python code using Snekbox."""
@@ -56,10 +69,14 @@ class Snekbox(Cog):
self.bot = bot
self.jobs = {}
- async def post_eval(self, code: str) -> dict:
+ async def post_eval(self, code: str, *, args: Optional[list[str]] = None) -> dict:
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
url = URLs.snekbox_eval_api
data = {"input": code}
+
+ if args is not None:
+ data["args"] = args
+
async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
return await resp.json()
@@ -144,8 +161,6 @@ class Snekbox(Cog):
Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters
and upload the full output to a paste service.
"""
- log.trace("Formatting output...")
-
output = output.rstrip("\n")
original_output = output # To be uploaded to a pasting service if needed
paste_link = None
@@ -185,20 +200,28 @@ class Snekbox(Cog):
return output, paste_link
- async def send_eval(self, ctx: Context, code: str) -> Message:
+ async def send_eval(
+ self,
+ ctx: Context,
+ code: str,
+ *,
+ args: Optional[list[str]] = None,
+ format_func: FormatFunc
+ ) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
Return the bot response.
"""
async with ctx.typing():
- results = await self.post_eval(code)
+ results = await self.post_eval(code, args=args)
msg, error = self.get_results_message(results)
if error:
output, paste_link = error, None
else:
- output, paste_link = await self.format_output(results["stdout"])
+ log.trace("Formatting output...")
+ output, paste_link = await format_func(results["stdout"])
icon = self.get_status_emoji(results)
msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```"
@@ -248,7 +271,7 @@ class Snekbox(Cog):
timeout=10
)
- code = await self.get_code(new_message)
+ code = await self.get_code(new_message, ctx.command)
await ctx.message.clear_reaction(REEVAL_EMOJI)
with contextlib.suppress(HTTPException):
await response.delete()
@@ -257,9 +280,9 @@ class Snekbox(Cog):
await ctx.message.clear_reaction(REEVAL_EMOJI)
return None
- return code
+ return self.prepare_input(code)
- async def get_code(self, message: Message) -> Optional[str]:
+ async def get_code(self, message: Message, command: Command) -> Optional[str]:
"""
Return the code from `message` to be evaluated.
@@ -269,7 +292,7 @@ class Snekbox(Cog):
log.trace(f"Getting context for message {message.id}.")
new_ctx = await self.bot.get_context(message)
- if new_ctx.command is self.eval_command:
+ if new_ctx.command is command:
log.trace(f"Message {message.id} invokes eval command.")
split = message.content.split(maxsplit=1)
code = split[1] if len(split) > 1 else None
@@ -279,25 +302,18 @@ class Snekbox(Cog):
return code
- @command(name="eval", aliases=("e",))
- @guild_only()
- @redirect_output(
- destination_channel=Channels.bot_commands,
- bypass_roles=EVAL_ROLES,
- categories=NO_EVAL_CATEGORIES,
- channels=NO_EVAL_CHANNELS,
- ping_user=False
- )
- async def eval_command(self, ctx: Context, *, code: str = None) -> None:
+ async def run_eval(
+ self,
+ ctx: Context,
+ code: str,
+ format_func: FormatFunc,
+ *,
+ args: Optional[list[str]] = None,
+ ) -> None:
"""
- Run Python code and get the results.
+ Handles checks, stats and re-evaluation of an eval.
- This command supports multiple lines of code, including code wrapped inside a formatted code
- block. Code can be re-evaluated by editing the original message within 10 seconds and
- clicking the reaction that subsequently appears.
-
- We've done our best to make this sandboxed, but do let us know if you manage to find an
- issue with it!
+ `format_func` is an async callable that takes a string (the output) and formats it to show to the user.
"""
if ctx.author.id in self.jobs:
await ctx.send(
@@ -306,10 +322,6 @@ class Snekbox(Cog):
)
return
- if not code: # None or empty string
- await ctx.send_help(ctx.command)
- return
-
if Roles.helpers in (role.id for role in ctx.author.roles):
self.bot.stats.incr("snekbox_usages.roles.helpers")
else:
@@ -326,9 +338,8 @@ class Snekbox(Cog):
while True:
self.jobs[ctx.author.id] = datetime.datetime.now()
- code = self.prepare_input(code)
try:
- response = await self.send_eval(ctx, code)
+ response = await self.send_eval(ctx, code, args=args, format_func=format_func)
finally:
del self.jobs[ctx.author.id]
@@ -337,6 +348,67 @@ class Snekbox(Cog):
break
log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}")
+ async def format_timeit_output(self, output: str) -> tuple[str, str]:
+ """
+ Parses the time from the end of the output given by timeit.
+
+ If an error happened, then it won't contain the time and instead proceed with regular formatting.
+ """
+ split_output = output.rstrip("\n").rsplit("\n", 1)
+ if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]):
+ return split_output[1], None
+
+ return await self.format_output(output)
+
+ @command(name="eval", aliases=("e",))
+ @guild_only()
+ @redirect_output(
+ destination_channel=Channels.bot_commands,
+ bypass_roles=EVAL_ROLES,
+ categories=NO_EVAL_CATEGORIES,
+ channels=NO_EVAL_CHANNELS,
+ ping_user=False
+ )
+ async def eval_command(self, ctx: Context, *, code: str) -> None:
+ """
+ Run Python code and get the results.
+
+ This command supports multiple lines of code, including code wrapped inside a formatted code
+ block. Code can be re-evaluated by editing the original message within 10 seconds and
+ clicking the reaction that subsequently appears.
+
+ We've done our best to make this sandboxed, but do let us know if you manage to find an
+ issue with it!
+ """
+ code = self.prepare_input(code)
+ await self.run_eval(ctx, code, format_func=self.format_output)
+
+ @command(name="timeit", aliases=("ti",))
+ @guild_only()
+ @redirect_output(
+ destination_channel=Channels.bot_commands,
+ bypass_roles=EVAL_ROLES,
+ categories=NO_EVAL_CATEGORIES,
+ channels=NO_EVAL_CHANNELS,
+ ping_user=False
+ )
+ async def timeit_command(self, ctx: Context, *, code: str) -> str:
+ """
+ Profile Python Code to find execution time.
+
+ This command supports multiple lines of code, including code wrapped inside a formatted code
+ block. Code can be re-evaluated by editing the original message within 10 seconds and
+ clicking the reaction that subsequently appears.
+
+ We've done our best to make this sandboxed, but do let us know if you manage to find an
+ issue with it!
+ """
+ code = self.prepare_input(code)
+ await self.run_eval(
+ ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")),
+ format_func=self.format_timeit_output, args=["-m", "timeit"]
+ )
+
def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
"""Return True if the edited message is the context message and the content was indeed modified."""
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 8bdeedd27..cbffaa6b0 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -162,7 +162,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
- self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.send_eval.assert_called_once_with(
+ ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output
+ )
self.cog.continue_eval.assert_called_once_with(ctx, response)
async def test_eval_command_evaluate_twice(self):
@@ -172,11 +174,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
self.cog.send_eval = AsyncMock(return_value=response)
self.cog.continue_eval = AsyncMock()
- self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None)
+ self.cog.continue_eval.side_effect = ('MyAwesomeFormattedCode', None)
await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
- self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.send_eval.assert_called_with(
+ ctx, 'MyAwesomeFormattedCode', args=None, format_func=self.cog.format_output
+ )
self.cog.continue_eval.assert_called_with(ctx, response)
async def test_eval_command_reject_two_eval_at_the_same_time(self):
@@ -191,12 +195,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
)
- async def test_eval_command_call_help(self):
- """Test if the eval command call the help command if no code is provided."""
- ctx = MockContext(command="sentinel")
- await self.cog.eval_command(self.cog, ctx=ctx, code='')
- ctx.send_help.assert_called_once_with(ctx.command)
-
async def test_send_eval(self):
"""Test the send_eval function."""
ctx = MockContext()
@@ -213,7 +211,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_eval = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output)
ctx.send.assert_called_once()
self.assertEqual(
@@ -224,7 +222,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict())
- self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
self.cog.format_output.assert_called_once_with('')
@@ -245,7 +243,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_eval = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output)
ctx.send.assert_called_once()
self.assertEqual(
@@ -254,7 +252,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
'\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'
)
- self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
self.cog.format_output.assert_called_once_with('Way too long beard')
@@ -274,7 +272,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mocked_filter_cog.filter_eval = AsyncMock(return_value=False)
self.bot.get_cog.return_value = mocked_filter_cog
- await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ await self.cog.send_eval(ctx, 'MyAwesomeCode', format_func=self.cog.format_output)
ctx.send.assert_called_once()
self.assertEqual(
@@ -282,7 +280,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
'@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'
)
- self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode', args=None)
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
self.cog.format_output.assert_not_called()
@@ -298,7 +296,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.get_code = create_autospec(self.cog.get_code, spec_set=True, return_value=expected)
actual = await self.cog.continue_eval(ctx, response)
- self.cog.get_code.assert_awaited_once_with(new_msg)
+ self.cog.get_code.assert_awaited_once_with(new_msg, ctx.command)
self.assertEqual(actual, expected)
self.bot.wait_for.assert_has_awaits(
(
@@ -343,7 +341,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.get_context.return_value = MockContext(command=command)
message = MockMessage(content=content)
- actual_code = await self.cog.get_code(message)
+ actual_code = await self.cog.get_code(message, self.cog.eval_command)
self.bot.get_context.assert_awaited_once_with(message)
self.assertEqual(actual_code, expected_code)