diff options
-rw-r--r-- | bot/exts/backend/error_handler.py | 22 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 44 | ||||
-rw-r--r-- | tests/helpers.py | 20 |
3 files changed, 56 insertions, 30 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index cc2b5ef56..561bf8068 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,7 +1,8 @@ import copy import difflib +import typing as t -from discord import Embed +from discord import Embed, Interaction from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -21,6 +22,10 @@ class ErrorHandler(Cog): def __init__(self, bot: Bot): self.bot = bot + @staticmethod + async def _can_run(_: Interaction) -> bool: + return False + def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" return Embed( @@ -159,7 +164,7 @@ class ErrorHandler(Cog): return True return False - async def try_get_tag(self, ctx: Context) -> None: + async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Interaction], bool] = False) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -168,27 +173,28 @@ class ErrorHandler(Cog): the context to prevent infinite recursion in the case of a CommandNotFound exception. """ tags_get_command = self.bot.get_command("tags get") + tags_get_command.can_run = can_run if can_run else self._can_run if not tags_get_command: log.debug("Not attempting to parse message as a tag as could not find `tags get` command.") return - ctx.invoked_from_error_handler = True + interaction.invoked_from_error_handler = True log_msg = "Cancelling attempt to fall back to a tag due to failed checks." try: - if not await tags_get_command.can_run(ctx): + if not await tags_get_command.can_run(interaction): log.debug(log_msg) return except errors.CommandError as tag_error: log.debug(log_msg) - await self.on_command_error(ctx, tag_error) + await self.on_command_error(interaction, tag_error) return - if await ctx.invoke(tags_get_command, argument_string=ctx.message.content): + if await interaction.invoke(tags_get_command, tag_name=interaction.message.content): return - if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): - await self.send_command_suggestion(ctx, ctx.invoked_with) + if not any(role.id in MODERATION_ROLES for role in interaction.user.roles): + await self.send_command_suggestion(interaction, interaction.invoked_with) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index adb0252a5..83bc3c4a1 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.backend import error_handler from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel, MockVoiceChannel +from tests.helpers import MockBot, MockContext, MockGuild, MockInteraction, MockRole, MockTextChannel, MockVoiceChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -331,7 +331,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() - self.ctx = MockContext() + self.interaction = MockInteraction() self.tag = Tags(self.bot) self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command @@ -339,57 +339,57 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_try_get_tag_get_command(self): """Should call `Bot.get_command` with `tags get` argument.""" self.bot.get_command.reset_mock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction) self.bot.get_command.assert_called_once_with("tags get") async def test_try_get_tag_invoked_from_error_handler(self): - """`self.ctx` should have `invoked_from_error_handler` `True`.""" - self.ctx.invoked_from_error_handler = False - await self.cog.try_get_tag(self.ctx) - self.assertTrue(self.ctx.invoked_from_error_handler) + """`self.interaction` should have `invoked_from_error_handler` `True`.""" + self.interaction.invoked_from_error_handler = False + await self.cog.try_get_tag(self.interaction) + self.assertTrue(self.interaction.invoked_from_error_handler) async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" self.tag.get_command.can_run = AsyncMock(return_value=False) - self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.interaction.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(return_value=False))) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() self.tag.get_command.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) - self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) + self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(side_effect=err))) + self.cog.on_command_error.assert_awaited_once_with(self.interaction, err) async def test_dont_call_suggestion_tag_sent(self): """Should never call command suggestion if tag is already sent.""" - self.ctx.message = MagicMock(content="foo") - self.ctx.invoke = AsyncMock(return_value=True) + self.interaction.message = MagicMock(content="foo") + self.interaction.invoke = AsyncMock(return_value=True) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction, AsyncMock()) self.cog.send_command_suggestion.assert_not_awaited() @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) async def test_dont_call_suggestion_if_user_mod(self): """Should not call command suggestion if user is a mod.""" - self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) - self.ctx.author.roles = [MockRole(id=1234)] + self.interaction.invoked_with = "foo" + self.interaction.invoke = AsyncMock(return_value=False) + self.interaction.user.roles = [MockRole(id=1234)] self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction, AsyncMock()) self.cog.send_command_suggestion.assert_not_awaited() async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" - self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) + self.interaction.invoked_with = "foo" + self.interaction.invoke = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) - self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") + await self.cog.try_get_tag(self.interaction, AsyncMock()) + self.cog.send_command_suggestion.assert_awaited_once_with(self.interaction, "foo") class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 4b980ac21..2d20b4d07 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -479,6 +479,26 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) +class MockInteraction(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Interaction objects. + + Instances of this class will follow the specifications of `discord.Interaction` + instances. For more information, see the `MockGuild` docstring. + """ + # spec_set = context_instance + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.me = kwargs.get('me', MockMember()) + self.client = kwargs.get('client', MockBot()) + self.guild = kwargs.get('guild', MockGuild()) + self.user = kwargs.get('user', MockMember()) + self.channel = kwargs.get('channel', MockTextChannel()) + self.message = kwargs.get('message', MockMessage()) + self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) + + attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) |