diff options
| -rw-r--r-- | bot/exts/backend/error_handler.py | 19 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 6 | 
2 files changed, 16 insertions, 9 deletions
| diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index cc2b5ef56..07248df5b 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -66,7 +66,7 @@ class ErrorHandler(Cog):          if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False):              if await self.try_silence(ctx):                  return -            if await self.try_run_eval(ctx): +            if await self.try_run_fixed_codeblock(ctx):                  return              await self.try_get_tag(ctx)  # Try to look for a tag with the command's name          elif isinstance(e, errors.UserInputError): @@ -190,9 +190,9 @@ class ErrorHandler(Cog):          if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):              await self.send_command_suggestion(ctx, ctx.invoked_with) -    async def try_run_eval(self, ctx: Context) -> bool: +    async def try_run_fixed_codeblock(self, ctx: Context) -> bool:          """ -        Attempt to run eval command with backticks directly after command. +        Attempt to run eval or timeit command with triple backticks directly after command.          For example: !eval```print("hi")``` @@ -204,11 +204,18 @@ class ErrorHandler(Cog):          msg.content = command + " " + sep + end          new_ctx = await self.bot.get_context(msg) -        eval_command = self.bot.get_command("eval") -        if eval_command is None or new_ctx.command != eval_command: +        if new_ctx.command is None:              return False -        log.debug("Running fixed eval command.") +        allowed_commands = [ +            self.bot.get_command("eval"), +            self.bot.get_command("timeit"), +        ] + +        if new_ctx.command not in allowed_commands: +            return False + +        log.debug("Running %r command with fixed codeblock.", new_ctx.command.qualified_name)          new_ctx.invoked_from_error_handler = True          await self.bot.invoke(new_ctx) diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index adb0252a5..092de0556 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -47,7 +47,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          )          self.cog.try_silence = AsyncMock()          self.cog.try_get_tag = AsyncMock() -        self.cog.try_run_eval = AsyncMock(return_value=False) +        self.cog.try_run_fixed_codeblock = AsyncMock(return_value=False)          for case in test_cases:              with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): @@ -75,7 +75,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          self.cog.try_silence = AsyncMock()          self.cog.try_get_tag = AsyncMock() -        self.cog.try_run_eval = AsyncMock() +        self.cog.try_run_fixed_codeblock = AsyncMock()          error = errors.CommandNotFound() @@ -83,7 +83,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):          self.cog.try_silence.assert_not_awaited()          self.cog.try_get_tag.assert_not_awaited() -        self.cog.try_run_eval.assert_not_awaited() +        self.cog.try_run_fixed_codeblock.assert_not_awaited()          self.ctx.send.assert_not_awaited()      async def test_error_handler_user_input_error(self): | 
