diff options
-rw-r--r-- | bot/exts/utils/snekbox.py | 47 |
1 files changed, 34 insertions, 13 deletions
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index b16a62479..49f1be17b 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -145,6 +145,24 @@ class Snekbox(Cog): return codeblocks @staticmethod + def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: + """ + Join the codeblocks into a single string, then return the code and the arguments in a tuple. + + If there are multiple codeblocks, insert the first one into the wrapped setup code. + """ + args = ["-m", "timeit"] + setup = "" + if len(codeblocks) > 1: + setup = codeblocks.pop(0) + + code = "\n".join(codeblocks) + + args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + + return code, args + + @staticmethod def get_results_message(results: dict) -> Tuple[str, str]: """Return a user-friendly message and error corresponding to the process's return code.""" stdout, returncode = results["stdout"], results["returncode"] @@ -273,11 +291,14 @@ class Snekbox(Cog): log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") return response - async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]: + async def continue_eval( + self, ctx: Context, response: Message, command: Command + ) -> Optional[tuple[str, Optional[list[str]]]]: """ Check if the eval session should continue. - Return the new code to evaluate or None if the eval session should be terminated. + If the code is to be evaluated, return the new code and the args if the command is the timeit command. + Otherwise return None if the eval session should be terminated. """ _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) @@ -303,9 +324,16 @@ class Snekbox(Cog): except asyncio.TimeoutError: await ctx.message.clear_reaction(REEVAL_EMOJI) - return None + return None, None + + codeblocks = self.prepare_input(code) + + if command is self.timeit_command: + return self.prepare_timeit_input(codeblocks) + else: + return "\n".join(codeblocks), None - return self.prepare_input(code) + return None, None async def get_code(self, message: Message, command: Command) -> Optional[str]: """ @@ -368,7 +396,7 @@ class Snekbox(Cog): finally: del self.jobs[ctx.author.id] - code = await self.continue_eval(ctx, response) + code, args = await self.continue_eval(ctx, response, ctx.command) if not code: break log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") @@ -431,15 +459,8 @@ class Snekbox(Cog): We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ - args = ["-m", "timeit"] - setup = "" codeblocks = self.prepare_input(code) - - if len(codeblocks) > 1: - setup = codeblocks.pop(0) - - code = "\n".join(codeblocks) - args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + code, args = self.prepare_timeit_input(codeblocks) await self.run_eval( ctx, code=code, format_func=self.format_timeit_output, args=args |