aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/error_handler.py40
-rw-r--r--bot/exts/info/tags.py90
-rw-r--r--bot/pagination.py4
-rw-r--r--bot/utils/messages.py2
-rw-r--r--tests/bot/exts/backend/test_error_handler.py50
-rw-r--r--tests/helpers.py11
6 files changed, 125 insertions, 72 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index 561bf8068..6561f84e4 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -2,7 +2,7 @@ import copy
import difflib
import typing as t
-from discord import Embed, Interaction
+from discord import Embed, Interaction, utils
from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors
from pydis_core.site_api import ResponseCodeError
from sentry_sdk import push_scope
@@ -23,8 +23,19 @@ class ErrorHandler(Cog):
self.bot = bot
@staticmethod
- async def _can_run(_: Interaction) -> bool:
- return False
+ async def _can_run(ctx: Context) -> bool:
+ """
+ Add checks for the `get_command_ctx` function here.
+
+ Use discord.utils to run the checks.
+ """
+ checks = []
+ predicates = checks
+ if not predicates:
+ # Since we have no checks, then we just return True.
+ return True
+
+ return await utils.async_all(predicate(ctx) for predicate in predicates)
def _get_error_embed(self, title: str, body: str) -> Embed:
"""Return an embed that contains the exception."""
@@ -164,7 +175,7 @@ class ErrorHandler(Cog):
return True
return False
- async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Interaction], bool] = False) -> None:
+ async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None:
"""
Attempt to display a tag by interpreting the command name as a tag name.
@@ -172,29 +183,30 @@ class ErrorHandler(Cog):
by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to
the context to prevent infinite recursion in the case of a CommandNotFound exception.
"""
- tags_get_command = self.bot.get_command("tags get")
- tags_get_command.can_run = can_run if can_run else self._can_run
- if not tags_get_command:
- log.debug("Not attempting to parse message as a tag as could not find `tags get` command.")
+ tags_cog = self.bot.get_cog("Tags")
+ if not tags_cog:
+ log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.")
return
+ tags_get_command = tags_cog.get_command_ctx
+ can_run = can_run if can_run else self._can_run
- interaction.invoked_from_error_handler = True
+ ctx.invoked_from_error_handler = True
log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
try:
- if not await tags_get_command.can_run(interaction):
+ if not await can_run(ctx):
log.debug(log_msg)
return
except errors.CommandError as tag_error:
log.debug(log_msg)
- await self.on_command_error(interaction, tag_error)
+ await self.on_command_error(ctx, tag_error)
return
- if await interaction.invoke(tags_get_command, tag_name=interaction.message.content):
+ if await tags_get_command(ctx, ctx.message.content):
return
- if not any(role.id in MODERATION_ROLES for role in interaction.user.roles):
- await self.send_command_suggestion(interaction, interaction.invoked_with)
+ if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
+ await self.send_command_suggestion(ctx, ctx.invoked_with)
async def try_run_eval(self, ctx: Context) -> bool:
"""
diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py
index 25c51def9..60f730586 100644
--- a/bot/exts/info/tags.py
+++ b/bot/exts/info/tags.py
@@ -8,7 +8,7 @@ from typing import Literal, NamedTuple, Optional, Union
import discord
import frontmatter
-from discord import Embed, Member, app_commands
+from discord import Embed, Interaction, Member, app_commands
from discord.ext.commands import Cog
from bot import constants
@@ -140,15 +140,8 @@ class Tags(Cog):
self.bot = bot
self.tags: dict[TagIdentifier, Tag] = {}
self.initialize_tags()
- self.bot.tree.copy_global_to(guild=discord.Object(id=GUILD_ID))
tag_group = app_commands.Group(name="tag", description="...")
- # search_tag = app_commands.Group(name="search", description="...", parent=tag_group)
-
- @Cog.listener()
- async def on_ready(self) -> None:
- """Called when the cog is ready."""
- await self.bot.tree.sync(guild=discord.Object(id=GUILD_ID))
def initialize_tags(self) -> None:
"""Load all tags from resources into `self.tags`."""
@@ -195,7 +188,8 @@ class Tags(Cog):
async def get_tag_embed(
self,
- interaction: discord.Interaction,
+ author: discord.Member,
+ channel: discord.TextChannel | discord.Thread,
tag_identifier: TagIdentifier,
) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]:
"""
@@ -206,7 +200,7 @@ class Tags(Cog):
filtered_tags = [
(ident, tag) for ident, tag in
self.get_fuzzy_matches(tag_identifier)[:10]
- if tag.accessible_by(interaction.user)
+ if tag.accessible_by(author)
]
# Try exact match, includes checking through alt names
@@ -225,10 +219,10 @@ class Tags(Cog):
tag = filtered_tags[0][1]
if tag is not None:
- if tag.on_cooldown_in(interaction.channel):
+ if tag.on_cooldown_in(channel):
log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.")
return COOLDOWN.obj
- tag.set_cooldown_for(interaction.channel)
+ tag.set_cooldown_for(channel)
self.bot.stats.incr(
f"tags.usages"
@@ -243,7 +237,7 @@ class Tags(Cog):
suggested_tags_text = "\n".join(
f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
for identifier, tag in filtered_tags
- if not tag.on_cooldown_in(interaction.channel)
+ if not tag.on_cooldown_in(channel)
)
return Embed(
title="Did you mean ...",
@@ -292,8 +286,37 @@ class Tags(Cog):
if identifier.group == group and tag.accessible_by(user)
)
+ async def get_command_ctx(
+ self,
+ ctx: discord.Context,
+ name: str
+ ) -> bool:
+ """Made specifically for `error_handler.py`, See `get_command` for more info."""
+ identifier = TagIdentifier.from_string(name)
+
+ if identifier.group is None:
+ # Try to find accessible tags from a group matching the identifier's name.
+ if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author):
+ await LinePaginator.paginate(
+ group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS
+ )
+ return True
+
+ embed = await self.get_tag_embed(ctx.author, ctx.channel, identifier)
+ if embed is None:
+ return False
+
+ if embed is not COOLDOWN.obj:
+
+ await wait_for_deletion(
+ await ctx.send(embed=embed),
+ (ctx.author.id,)
+ )
+ # A valid tag was found and was either sent, or is on cooldown
+ return True
+
@tag_group.command(name="get")
- async def get_command(self, interaction: discord.Interaction, *, tag_name: Optional[str]) -> bool:
+ async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool:
"""
If a single argument matching a group name is given, list all accessible tags from that group
Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name.
@@ -303,7 +326,7 @@ class Tags(Cog):
Returns True if a message was sent, or if the tag is on cooldown.
Returns False if no message was sent.
""" # noqa: D205, D415
- if not tag_name:
+ if not name:
if self.tags:
await LinePaginator.paginate(
self.accessible_tags(interaction.user),
@@ -314,7 +337,7 @@ class Tags(Cog):
await interaction.response.send_message(embed=Embed(description="**There are no tags!**"))
return True
- identifier = TagIdentifier.from_string(tag_name)
+ identifier = TagIdentifier.from_string(name)
if identifier.group is None:
# Try to find accessible tags from a group matching the identifier's name.
@@ -324,33 +347,43 @@ class Tags(Cog):
)
return True
- embed = await self.get_tag_embed(interaction, identifier)
+ embed = await self.get_tag_embed(interaction.user, interaction.channel, identifier)
+ ephemeral = False
if embed is None:
- return False
-
- if embed is not COOLDOWN.obj:
+ description = f"**There are no tags matching the name {name!r}!**"
+ embed = Embed(description=description)
+ ephemeral = True
+ elif embed is COOLDOWN.obj:
+ description = f"Tag {name!r} is on cooldown."
+ embed = Embed(description=description)
+ ephemeral = True
+
+ await interaction.response.send_message(embed=embed, ephemeral=ephemeral)
+ if not ephemeral:
await wait_for_deletion(
- await interaction.response.send_message(embed=embed),
+ await interaction.original_response(),
(interaction.user.id,)
)
+
# A valid tag was found and was either sent, or is on cooldown
return True
- @get_command.autocomplete("tag_name")
- async def tag_name_autocomplete(
+ @get_command.autocomplete("name")
+ async def name_autocomplete(
self,
- interaction: discord.Interaction,
+ interaction: Interaction,
current: str
) -> list[app_commands.Choice[str]]:
"""Autocompleter for `/tag get` command."""
- tag_names = [tag.name for tag in self.tags.keys()]
- return [
+ names = [tag.name for tag in self.tags.keys()]
+ choices = [
app_commands.Choice(name=tag, value=tag)
- for tag in tag_names if current.lower() in tag
+ for tag in names if current.lower() in tag
]
+ return choices[:25] if len(choices) > 25 else choices
@tag_group.command(name="list")
- async def list_command(self, interaction: discord.Interaction) -> bool:
+ async def list_command(self, interaction: Interaction) -> bool:
"""Lists all accessible tags."""
if self.tags:
await LinePaginator.paginate(
@@ -367,4 +400,3 @@ class Tags(Cog):
async def setup(bot: Bot) -> None:
"""Load the Tags cog."""
await bot.add_cog(Tags(bot))
- await bot.tree.sync(guild=discord.Object(id=GUILD_ID))
diff --git a/bot/pagination.py b/bot/pagination.py
index 1c63a4768..c39ce211b 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -190,8 +190,8 @@ class LinePaginator(Paginator):
@classmethod
async def paginate(
cls,
- lines: t.List[str],
- ctx: t.Union[Context, discord.Interaction],
+ lines: list[str],
+ ctx: Context | discord.Interaction,
embed: discord.Embed,
prefix: str = "",
suffix: str = "",
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index 27f2eac97..f6bdceaef 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -58,7 +58,7 @@ def reaction_check(
async def wait_for_deletion(
- message: discord.Message,
+ message: discord.Message | discord.InteractionMessage,
user_ids: Sequence[int],
deletion_emojis: Sequence[str] = (Emojis.trashcan,),
timeout: float = 60 * 5,
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 83bc3c4a1..14e7a4125 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -9,7 +9,7 @@ from bot.exts.backend import error_handler
from bot.exts.info.tags import Tags
from bot.exts.moderation.silence import Silence
from bot.utils.checks import InWhitelistCheckFailure
-from tests.helpers import MockBot, MockContext, MockGuild, MockInteraction, MockRole, MockTextChannel, MockVoiceChannel
+from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel, MockVoiceChannel
class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
@@ -331,65 +331,65 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.bot = MockBot()
- self.interaction = MockInteraction()
+ self.ctx = MockContext()
self.tag = Tags(self.bot)
self.cog = error_handler.ErrorHandler(self.bot)
- self.bot.get_command.return_value = self.tag.get_command
+ self.bot.get_cog.return_value = self.tag
async def test_try_get_tag_get_command(self):
"""Should call `Bot.get_command` with `tags get` argument."""
- self.bot.get_command.reset_mock()
- await self.cog.try_get_tag(self.interaction)
- self.bot.get_command.assert_called_once_with("tags get")
+ self.bot.get_cog.reset_mock()
+ await self.cog.try_get_tag(self.ctx)
+ self.bot.get_cog.assert_called_once_with("Tags")
async def test_try_get_tag_invoked_from_error_handler(self):
- """`self.interaction` should have `invoked_from_error_handler` `True`."""
- self.interaction.invoked_from_error_handler = False
- await self.cog.try_get_tag(self.interaction)
- self.assertTrue(self.interaction.invoked_from_error_handler)
+ """`self.ctx` should have `invoked_from_error_handler` `True`."""
+ self.ctx.invoked_from_error_handler = False
+ await self.cog.try_get_tag(self.ctx)
+ self.assertTrue(self.ctx.invoked_from_error_handler)
async def test_try_get_tag_no_permissions(self):
"""Test how to handle checks failing."""
self.tag.get_command.can_run = AsyncMock(return_value=False)
- self.interaction.invoked_with = "foo"
- self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(return_value=False)))
+ self.ctx.invoked_with = "foo"
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False)))
async def test_try_get_tag_command_error(self):
"""Should call `on_command_error` when `CommandError` raised."""
err = errors.CommandError()
self.tag.get_command.can_run = AsyncMock(side_effect=err)
self.cog.on_command_error = AsyncMock()
- self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(side_effect=err)))
- self.cog.on_command_error.assert_awaited_once_with(self.interaction, err)
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err)))
+ self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
async def test_dont_call_suggestion_tag_sent(self):
"""Should never call command suggestion if tag is already sent."""
- self.interaction.message = MagicMock(content="foo")
- self.interaction.invoke = AsyncMock(return_value=True)
+ self.ctx.message = MagicMock(content="foo")
+ self.tag.get_command_ctx = AsyncMock(return_value=True)
self.cog.send_command_suggestion = AsyncMock()
- await self.cog.try_get_tag(self.interaction, AsyncMock())
+ await self.cog.try_get_tag(self.ctx)
self.cog.send_command_suggestion.assert_not_awaited()
@patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234])
async def test_dont_call_suggestion_if_user_mod(self):
"""Should not call command suggestion if user is a mod."""
- self.interaction.invoked_with = "foo"
- self.interaction.invoke = AsyncMock(return_value=False)
- self.interaction.user.roles = [MockRole(id=1234)]
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=False)
+ self.ctx.author.roles = [MockRole(id=1234)]
self.cog.send_command_suggestion = AsyncMock()
- await self.cog.try_get_tag(self.interaction, AsyncMock())
+ await self.cog.try_get_tag(self.ctx)
self.cog.send_command_suggestion.assert_not_awaited()
async def test_call_suggestion(self):
"""Should call command suggestion if user is not a mod."""
- self.interaction.invoked_with = "foo"
- self.interaction.invoke = AsyncMock(return_value=False)
+ self.ctx.invoked_with = "foo"
+ self.ctx.invoke = AsyncMock(return_value=False)
self.cog.send_command_suggestion = AsyncMock()
- await self.cog.try_get_tag(self.interaction, AsyncMock())
- self.cog.send_command_suggestion.assert_awaited_once_with(self.interaction, "foo")
+ await self.cog.try_get_tag(self.ctx)
+ self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo")
class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
diff --git a/tests/helpers.py b/tests/helpers.py
index 2d20b4d07..0d955b521 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -486,7 +486,6 @@ class MockInteraction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Interaction`
instances. For more information, see the `MockGuild` docstring.
"""
- # spec_set = context_instance
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@@ -550,6 +549,16 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
self.channel = kwargs.get('channel', MockTextChannel())
+class MockInteractionMessage(MockMessage):
+ """
+ A MagicMock subclass to mock InteractionMessage objects.
+
+ Instances of this class will follow the specifications of `discord.InteractionMessage` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ pass
+
+
emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'}
emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data)