aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar ToxicKidz <[email protected]>2021-05-20 14:07:26 -0400
committerGravatar ToxicKidz <[email protected]>2021-05-20 14:07:26 -0400
commitcd921fbdb73ffb0b83ab83ea2cc004de1777724a (patch)
tree857f5aa152780e87df3fc5388bbd4428761514f3
parentMerge pull request #1593 from python-discord/flake-8-isn't-a-task (diff)
feat: Add the timeit command
-rw-r--r--bot/exts/utils/snekbox.py142
1 files changed, 107 insertions, 35 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index b1f1ba6a8..615956637 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -6,10 +6,10 @@ import re
import textwrap
from functools import partial
from signal import Signals
-from typing import Optional, Tuple
+from typing import Awaitable, Callable, Optional, Tuple
from discord import HTTPException, Message, NotFound, Reaction, User
-from discord.ext.commands import Cog, Context, command, guild_only
+from discord.ext.commands import Cog, Command, Context, command, guild_only
from bot.bot import Bot
from bot.constants import Categories, Channels, Roles, URLs
@@ -36,6 +36,17 @@ RAW_CODE_REGEX = re.compile(
re.DOTALL # "." also matches newlines
)
+TIMEIT_EVAL_WRAPPER = """
+from contextlib import redirect_stdout
+from io import StringIO
+
+with redirect_stdout(StringIO()):
+ del redirect_stdout, StringIO
+{code}
+"""
+
+TIMEIT_OUTPUT_REGEX = re.compile(r"\d+ loops, best of \d+: \d(?:\.\d\d?)? [mnu]?sec per loop")
+
MAX_PASTE_LEN = 10000
# `!eval` command whitelists and blacklists.
@@ -48,6 +59,8 @@ SIGKILL = 9
REEVAL_EMOJI = '\U0001f501' # :repeat:
REEVAL_TIMEOUT = 30
+FormatFunc = Callable[[str], Awaitable[tuple[str, Optional[str]]]]
+
class Snekbox(Cog):
"""Safe evaluation of Python code using Snekbox."""
@@ -56,10 +69,14 @@ class Snekbox(Cog):
self.bot = bot
self.jobs = {}
- async def post_eval(self, code: str) -> dict:
+ async def post_eval(self, code: str, *, args: Optional[list[str]]) -> dict:
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
url = URLs.snekbox_eval_api
data = {"input": code}
+
+ if args is not None:
+ data["args"] = args
+
async with self.bot.http_session.post(url, json=data, raise_for_status=True) as resp:
return await resp.json()
@@ -144,8 +161,6 @@ class Snekbox(Cog):
Prepend each line with a line number. Truncate if there are over 10 lines or 1000 characters
and upload the full output to a paste service.
"""
- log.trace("Formatting output...")
-
output = output.rstrip("\n")
original_output = output # To be uploaded to a pasting service if needed
paste_link = None
@@ -185,20 +200,28 @@ class Snekbox(Cog):
return output, paste_link
- async def send_eval(self, ctx: Context, code: str) -> Message:
+ async def send_eval(
+ self,
+ ctx: Context,
+ code: str,
+ *,
+ args: Optional[list[str]],
+ format_func: FormatFunc
+ ) -> Message:
"""
Evaluate code, format it, and send the output to the corresponding channel.
Return the bot response.
"""
async with ctx.typing():
- results = await self.post_eval(code)
+ results = await self.post_eval(code, args=args)
msg, error = self.get_results_message(results)
if error:
output, paste_link = error, None
else:
- output, paste_link = await self.format_output(results["stdout"])
+ log.trace("Formatting output...")
+ output, paste_link = await format_func(results["stdout"])
icon = self.get_status_emoji(results)
msg = f"{ctx.author.mention} {icon} {msg}.\n\n```\n{output}\n```"
@@ -247,7 +270,7 @@ class Snekbox(Cog):
timeout=10
)
- code = await self.get_code(new_message)
+ code = await self.get_code(new_message, ctx.command)
await ctx.message.clear_reaction(REEVAL_EMOJI)
with contextlib.suppress(HTTPException):
await response.delete()
@@ -256,9 +279,9 @@ class Snekbox(Cog):
await ctx.message.clear_reaction(REEVAL_EMOJI)
return None
- return code
+ return self.prepare_input(code)
- async def get_code(self, message: Message) -> Optional[str]:
+ async def get_code(self, message: Message, command: Command) -> Optional[str]:
"""
Return the code from `message` to be evaluated.
@@ -268,7 +291,7 @@ class Snekbox(Cog):
log.trace(f"Getting context for message {message.id}.")
new_ctx = await self.bot.get_context(message)
- if new_ctx.command is self.eval_command:
+ if new_ctx.command is command:
log.trace(f"Message {message.id} invokes eval command.")
split = message.content.split(maxsplit=1)
code = split[1] if len(split) > 1 else None
@@ -278,25 +301,18 @@ class Snekbox(Cog):
return code
- @command(name="eval", aliases=("e",))
- @guild_only()
- @redirect_output(
- destination_channel=Channels.bot_commands,
- bypass_roles=EVAL_ROLES,
- categories=NO_EVAL_CATEGORIES,
- channels=NO_EVAL_CHANNELS,
- ping_user=False
- )
- async def eval_command(self, ctx: Context, *, code: str = None) -> None:
+ async def run_eval(
+ self,
+ ctx: Context,
+ code: str,
+ format_func: FormatFunc,
+ *,
+ args: Optional[list[str]] = None,
+ ) -> None:
"""
- Run Python code and get the results.
+ Handles checks, stats and re-evaluation of an eval.
- This command supports multiple lines of code, including code wrapped inside a formatted code
- block. Code can be re-evaluated by editing the original message within 10 seconds and
- clicking the reaction that subsequently appears.
-
- We've done our best to make this sandboxed, but do let us know if you manage to find an
- issue with it!
+ `format_func` is an async callable that takes a string (the output) and formats it to show to the user.
"""
if ctx.author.id in self.jobs:
await ctx.send(
@@ -305,10 +321,6 @@ class Snekbox(Cog):
)
return
- if not code: # None or empty string
- await ctx.send_help(ctx.command)
- return
-
if Roles.helpers in (role.id for role in ctx.author.roles):
self.bot.stats.incr("snekbox_usages.roles.helpers")
else:
@@ -325,9 +337,8 @@ class Snekbox(Cog):
while True:
self.jobs[ctx.author.id] = datetime.datetime.now()
- code = self.prepare_input(code)
try:
- response = await self.send_eval(ctx, code)
+ response = await self.send_eval(ctx, code, args=args, format_func=format_func)
finally:
del self.jobs[ctx.author.id]
@@ -336,6 +347,67 @@ class Snekbox(Cog):
break
log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}")
+ async def format_timeit_output(self, output: str) -> tuple[str, str]:
+ """
+ Parses the time from the end of the output given by timeit.
+
+ If an error happened, then it won't contain the time and instead proceed with regular formatting.
+ """
+ split_output = output.rstrip("\n").rsplit("\n", 1)
+ if len(split_output) == 2 and TIMEIT_OUTPUT_REGEX.fullmatch(split_output[1]):
+ return split_output[1], None
+
+ return await self.format_output(output)
+
+ @command(name="eval", aliases=("e",))
+ @guild_only()
+ @redirect_output(
+ destination_channel=Channels.bot_commands,
+ bypass_roles=EVAL_ROLES,
+ categories=NO_EVAL_CATEGORIES,
+ channels=NO_EVAL_CHANNELS,
+ ping_user=False
+ )
+ async def eval_command(self, ctx: Context, *, code: str) -> None:
+ """
+ Run Python code and get the results.
+
+ This command supports multiple lines of code, including code wrapped inside a formatted code
+ block. Code can be re-evaluated by editing the original message within 10 seconds and
+ clicking the reaction that subsequently appears.
+
+ We've done our best to make this sandboxed, but do let us know if you manage to find an
+ issue with it!
+ """
+ code = self.prepare_input(code)
+ await self.run_eval(ctx, code, format_func=self.format_output)
+
+ @command(name="timeit", aliases=("ti",))
+ @guild_only()
+ @redirect_output(
+ destination_channel=Channels.bot_commands,
+ bypass_roles=EVAL_ROLES,
+ categories=NO_EVAL_CATEGORIES,
+ channels=NO_EVAL_CHANNELS,
+ ping_user=False
+ )
+ async def timeit_command(self, ctx: Context, *, code: str) -> str:
+ """
+ Profile Python Code to find execution time.
+
+ This command supports multiple lines of code, including code wrapped inside a formatted code
+ block. Code can be re-evaluated by editing the original message within 10 seconds and
+ clicking the reaction that subsequently appears.
+
+ We've done our best to make this sandboxed, but do let us know if you manage to find an
+ issue with it!
+ """
+ code = self.prepare_input(code)
+ await self.run_eval(
+ ctx, TIMEIT_EVAL_WRAPPER.format(code=textwrap.indent(code, " ")),
+ format_func=self.format_timeit_output, args=["-m", "timeit"]
+ )
+
def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
"""Return True if the edited message is the context message and the content was indeed modified."""