aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/utils/snekbox.py47
1 files changed, 34 insertions, 13 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index b16a62479..49f1be17b 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -145,6 +145,24 @@ class Snekbox(Cog):
return codeblocks
@staticmethod
+ def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]:
+ """
+ Join the codeblocks into a single string, then return the code and the arguments in a tuple.
+
+ If there are multiple codeblocks, insert the first one into the wrapped setup code.
+ """
+ args = ["-m", "timeit"]
+ setup = ""
+ if len(codeblocks) > 1:
+ setup = codeblocks.pop(0)
+
+ code = "\n".join(codeblocks)
+
+ args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)])
+
+ return code, args
+
+ @staticmethod
def get_results_message(results: dict) -> Tuple[str, str]:
"""Return a user-friendly message and error corresponding to the process's return code."""
stdout, returncode = results["stdout"], results["returncode"]
@@ -273,11 +291,14 @@ class Snekbox(Cog):
log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
return response
- async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]:
+ async def continue_eval(
+ self, ctx: Context, response: Message, command: Command
+ ) -> Optional[tuple[str, Optional[list[str]]]]:
"""
Check if the eval session should continue.
- Return the new code to evaluate or None if the eval session should be terminated.
+ If the code is to be evaluated, return the new code and the args if the command is the timeit command.
+ Otherwise return None if the eval session should be terminated.
"""
_predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx)
_predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx)
@@ -303,9 +324,16 @@ class Snekbox(Cog):
except asyncio.TimeoutError:
await ctx.message.clear_reaction(REEVAL_EMOJI)
- return None
+ return None, None
+
+ codeblocks = self.prepare_input(code)
+
+ if command is self.timeit_command:
+ return self.prepare_timeit_input(codeblocks)
+ else:
+ return "\n".join(codeblocks), None
- return self.prepare_input(code)
+ return None, None
async def get_code(self, message: Message, command: Command) -> Optional[str]:
"""
@@ -368,7 +396,7 @@ class Snekbox(Cog):
finally:
del self.jobs[ctx.author.id]
- code = await self.continue_eval(ctx, response)
+ code, args = await self.continue_eval(ctx, response, ctx.command)
if not code:
break
log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}")
@@ -431,15 +459,8 @@ class Snekbox(Cog):
We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
- args = ["-m", "timeit"]
- setup = ""
codeblocks = self.prepare_input(code)
-
- if len(codeblocks) > 1:
- setup = codeblocks.pop(0)
-
- code = "\n".join(codeblocks)
- args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)])
+ code, args = self.prepare_timeit_input(codeblocks)
await self.run_eval(
ctx, code=code, format_func=self.format_timeit_output, args=args