diff options
| -rw-r--r-- | bot/exts/utils/snekbox.py | 47 | 
1 files changed, 34 insertions, 13 deletions
| diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 0d8da5e56..718cff890 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -145,6 +145,24 @@ class Snekbox(Cog):          return codeblocks      @staticmethod +    def prepare_timeit_input(codeblocks: list[str]) -> tuple[str, list[str]]: +        """ +        Join the codeblocks into a single string, then return the code and the arguments in a tuple. + +        If there are multiple codeblocks, insert the first one into the wrapped setup code. +        """ +        args = ["-m", "timeit"] +        setup = "" +        if len(codeblocks) > 1: +            setup = codeblocks.pop(0) + +        code = "\n".join(codeblocks) + +        args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) + +        return code, args + +    @staticmethod      def get_results_message(results: dict) -> Tuple[str, str]:          """Return a user-friendly message and error corresponding to the process's return code."""          stdout, returncode = results["stdout"], results["returncode"] @@ -273,11 +291,14 @@ class Snekbox(Cog):              log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")          return response -    async def continue_eval(self, ctx: Context, response: Message) -> Optional[list[str]]: +    async def continue_eval( +        self, ctx: Context, response: Message, command: Command +    ) -> Optional[tuple[str, Optional[list[str]]]]:          """          Check if the eval session should continue. -        Return the new code to evaluate or None if the eval session should be terminated. +        If the code is to be evaluated, return the new code and the args if the command is the timeit command. +        Otherwise return None if the eval session should be terminated.          """          _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx)          _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) @@ -303,9 +324,16 @@ class Snekbox(Cog):              except asyncio.TimeoutError:                  await ctx.message.clear_reaction(REEVAL_EMOJI) -                return None +                return None, None + +            codeblocks = self.prepare_input(code) + +            if command is self.timeit_command: +                return self.prepare_timeit_input(codeblocks) +            else: +                return "\n".join(codeblocks), None -            return self.prepare_input(code) +        return None, None      async def get_code(self, message: Message, command: Command) -> Optional[str]:          """ @@ -368,7 +396,7 @@ class Snekbox(Cog):              finally:                  del self.jobs[ctx.author.id] -            code = await self.continue_eval(ctx, response) +            code, args = await self.continue_eval(ctx, response, ctx.command)              if not code:                  break              log.info(f"Re-evaluating code from message {ctx.message.id}:\n{code}") @@ -431,15 +459,8 @@ class Snekbox(Cog):          We've done our best to make this sandboxed, but do let us know if you manage to find an          issue with it!          """ -        args = ["-m", "timeit"] -        setup = ""          codeblocks = self.prepare_input(code) - -        if len(codeblocks) > 1: -            setup = codeblocks.pop(0) - -        code = "\n".join(codeblocks) -        args.extend(["-s", TIMEIT_SETUP_WRAPPER.format(setup=setup)]) +        code, args = self.prepare_timeit_input(codeblocks)          await self.run_eval(              ctx, code=code, format_func=self.format_timeit_output, args=args | 
