aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/error_handler.py37
-rw-r--r--bot/exts/info/tags.py205
-rw-r--r--bot/pagination.py22
-rw-r--r--bot/utils/messages.py2
-rw-r--r--tests/bot/exts/backend/test_error_handler.py14
-rw-r--r--tests/helpers.py29
6 files changed, 167 insertions, 142 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index 07248df5b..8883f7566 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -1,7 +1,7 @@
import copy
import difflib
-from discord import Embed
+from discord import Embed, Member
from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors
from pydis_core.site_api import ResponseCodeError
from sentry_sdk import push_scope
@@ -167,28 +167,33 @@ class ErrorHandler(Cog):
by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to
the context to prevent infinite recursion in the case of a CommandNotFound exception.
"""
- tags_get_command = self.bot.get_command("tags get")
- if not tags_get_command:
- log.debug("Not attempting to parse message as a tag as could not find `tags get` command.")
+ tags_cog = self.bot.get_cog("Tags")
+ if not tags_cog:
+ log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.")
return
+ tags_get_command = tags_cog.get_command_ctx
- ctx.invoked_from_error_handler = True
+ maybe_tag_name = ctx.invoked_with
+ if not maybe_tag_name or not isinstance(ctx.author, Member):
+ return
- log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
+ ctx.invoked_from_error_handler = True
try:
- if not await tags_get_command.can_run(ctx):
- log.debug(log_msg)
+ if not await self.bot.can_run(ctx):
+ log.debug("Cancelling attempt to fall back to a tag due to failed checks.")
return
- except errors.CommandError as tag_error:
- log.debug(log_msg)
- await self.on_command_error(ctx, tag_error)
- return
- if await ctx.invoke(tags_get_command, argument_string=ctx.message.content):
- return
+ if await tags_get_command(ctx, maybe_tag_name):
+ return
- if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
- await self.send_command_suggestion(ctx, ctx.invoked_with)
+ if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
+ await self.send_command_suggestion(ctx, maybe_tag_name)
+ except Exception as err:
+ log.debug("Error while attempting to invoke tag fallback.")
+ if isinstance(err, errors.CommandError):
+ await self.on_command_error(ctx, err)
+ else:
+ await self.on_command_error(ctx, errors.CommandInvokeError(err))
async def try_run_fixed_codeblock(self, ctx: Context) -> bool:
"""
diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py
index 83d3a9d93..309f22cad 100644
--- a/bot/exts/info/tags.py
+++ b/bot/exts/info/tags.py
@@ -4,12 +4,12 @@ import enum
import re
import time
from pathlib import Path
-from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union
+from typing import Literal, NamedTuple, Optional, Union
import discord
import frontmatter
-from discord import Embed, Member
-from discord.ext.commands import Cog, Context, group
+from discord import Embed, Interaction, Member, app_commands
+from discord.ext.commands import Cog, Context
from bot import constants
from bot.bot import Bot
@@ -91,7 +91,7 @@ class Tag:
embed.description = self.content
return embed
- def accessible_by(self, member: discord.Member) -> bool:
+ def accessible_by(self, member: Member) -> bool:
"""Check whether `member` can access the tag."""
return bool(
not self._restricted_to
@@ -182,101 +182,22 @@ class Tags(Cog):
return suggestions
- def _get_tags_via_content(
- self,
- check: Callable[[Iterable], bool],
- keywords: str,
- user: Member,
- ) -> list[tuple[TagIdentifier, Tag]]:
- """
- Search for tags via contents.
-
- `predicate` will be the built-in any, all, or a custom callable. Must return a bool.
- """
- keywords_processed = []
- for keyword in keywords.split(","):
- keyword_sanitized = keyword.strip().casefold()
- if not keyword_sanitized:
- # this happens when there are leading / trailing / consecutive comma.
- continue
- keywords_processed.append(keyword_sanitized)
-
- if not keywords_processed:
- # after sanitizing, we can end up with an empty list, for example when keywords is ","
- # in that case, we simply want to search for such keywords directly instead.
- keywords_processed = [keywords]
-
- matching_tags = []
- for identifier, tag in self.tags.items():
- matches = (query in tag.content.casefold() for query in keywords_processed)
- if tag.accessible_by(user) and check(matches):
- matching_tags.append((identifier, tag))
-
- return matching_tags
-
- async def _send_matching_tags(
- self,
- ctx: Context,
- keywords: str,
- matching_tags: list[tuple[TagIdentifier, Tag]],
- ) -> None:
- """Send the result of matching tags to user."""
- if len(matching_tags) == 1:
- await ctx.send(embed=matching_tags[0][1].embed)
- elif matching_tags:
- is_plural = keywords.strip().count(" ") > 0 or keywords.strip().count(",") > 0
- embed = Embed(
- title=f"Here are the tags containing the given keyword{'s' * is_plural}:",
- )
- await LinePaginator.paginate(
- sorted(
- f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}"
- for identifier, _ in matching_tags
- ),
- ctx,
- embed,
- **self.PAGINATOR_DEFAULTS,
- )
-
- @group(name="tags", aliases=("tag", "t"), invoke_without_command=True, usage="[tag_group] [tag_name]")
- async def tags_group(self, ctx: Context, *, argument_string: Optional[str]) -> None:
- """Show all known tags, a single tag, or run a subcommand."""
- await self.get_command(ctx, argument_string=argument_string)
-
- @tags_group.group(name="search", invoke_without_command=True)
- async def search_tag_content(self, ctx: Context, *, keywords: str) -> None:
- """
- Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.
-
- Only search for tags that has ALL the keywords.
- """
- matching_tags = self._get_tags_via_content(all, keywords, ctx.author)
- await self._send_matching_tags(ctx, keywords, matching_tags)
-
- @search_tag_content.command(name="any")
- async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = "any") -> None:
- """
- Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.
-
- Search for tags that has ANY of the keywords.
- """
- matching_tags = self._get_tags_via_content(any, keywords or "any", ctx.author)
- await self._send_matching_tags(ctx, keywords, matching_tags)
-
async def get_tag_embed(
self,
- ctx: Context,
+ member: Member,
+ channel: discord.abc.Messageable,
tag_identifier: TagIdentifier,
) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]:
"""
- Generate an embed of the requested tag or of suggestions if the tag doesn't exist/isn't accessible by the user.
+ Generate an embed of the requested tag or of suggestions if the tag doesn't exist
+ or isn't accessible by the member.
If the requested tag is on cooldown return `COOLDOWN.obj`, otherwise if no suggestions were found return None.
- """
+ """ # noqa: D205, D415
filtered_tags = [
(ident, tag) for ident, tag in
self.get_fuzzy_matches(tag_identifier)[:10]
- if tag.accessible_by(ctx.author)
+ if tag.accessible_by(member)
]
# Try exact match, includes checking through alt names
@@ -295,10 +216,10 @@ class Tags(Cog):
tag = filtered_tags[0][1]
if tag is not None:
- if tag.on_cooldown_in(ctx.channel):
+ if tag.on_cooldown_in(channel):
log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.")
return COOLDOWN.obj
- tag.set_cooldown_for(ctx.channel)
+ tag.set_cooldown_for(channel)
self.bot.stats.incr(
f"tags.usages"
@@ -313,15 +234,15 @@ class Tags(Cog):
suggested_tags_text = "\n".join(
f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
for identifier, tag in filtered_tags
- if not tag.on_cooldown_in(ctx.channel)
+ if not tag.on_cooldown_in(channel)
)
return Embed(
title="Did you mean ...",
description=suggested_tags_text
)
- def accessible_tags(self, user: Member) -> list[str]:
- """Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted."""
+ def accessible_tags(self, member: Member) -> list[str]:
+ """Return a formatted list of tags that are accessible by `member`; groups first, and alphabetically sorted."""
def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str:
group, name = tag_item[0]
if group is None:
@@ -338,7 +259,7 @@ class Tags(Cog):
if identifier.group != current_group:
if not group_accessible:
- # Remove group separator line if no tags in the previous group were accessible by the user.
+ # Remove group separator line if no tags in the previous group were accessible by the member.
result_lines.pop()
# A new group began, add a separator with the group name.
current_group = identifier.group
@@ -348,22 +269,55 @@ class Tags(Cog):
else:
result_lines.append("\n\N{BULLET}")
- if tag.accessible_by(user):
+ if tag.accessible_by(member):
result_lines.append(f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}")
group_accessible = True
return result_lines
- def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]:
- """Return a formatted list of tags in `group`, that are accessible by `user`."""
+ def accessible_tags_in_group(self, group: str, member: Member) -> list[str]:
+ """Return a formatted list of tags in `group`, that are accessible by `member`."""
return sorted(
f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
for identifier, tag in self.tags.items()
- if identifier.group == group and tag.accessible_by(user)
+ if identifier.group == group and tag.accessible_by(member)
)
- @tags_group.command(name="get", aliases=("show", "g"), usage="[tag_group] [tag_name]")
- async def get_command(self, ctx: Context, *, argument_string: Optional[str]) -> bool:
+ async def get_command_ctx(
+ self,
+ ctx: Context,
+ name: str
+ ) -> bool:
+ """
+ Made specifically for `ErrorHandler().try_get_tag` to handle sending tags through ctx.
+
+ See `get_command` for more info, but here name is not optional unlike `get_command`.
+ """
+ identifier = TagIdentifier.from_string(name)
+
+ if identifier.group is None:
+ # Try to find accessible tags from a group matching the identifier's name.
+ if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author):
+ await LinePaginator.paginate(
+ group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS
+ )
+ return True
+
+ embed = await self.get_tag_embed(ctx.author, ctx.channel, identifier)
+ if embed is None:
+ return False
+
+ if embed is not COOLDOWN.obj:
+
+ await wait_for_deletion(
+ await ctx.send(embed=embed),
+ (ctx.author.id,)
+ )
+ # A valid tag was found and was either sent, or is on cooldown
+ return True
+
+ @app_commands.command(name="tag")
+ async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool:
"""
If a single argument matching a group name is given, list all accessible tags from that group
Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name.
@@ -373,37 +327,62 @@ class Tags(Cog):
Returns True if a message was sent, or if the tag is on cooldown.
Returns False if no message was sent.
""" # noqa: D205, D415
- if not argument_string:
+ if not name:
if self.tags:
await LinePaginator.paginate(
- self.accessible_tags(ctx.author), ctx, Embed(title="Available tags"), **self.PAGINATOR_DEFAULTS
+ self.accessible_tags(interaction.user),
+ interaction, Embed(title="Available tags"),
+ **self.PAGINATOR_DEFAULTS,
)
else:
- await ctx.send(embed=Embed(description="**There are no tags!**"))
+ await interaction.response.send_message(embed=Embed(description="**There are no tags!**"))
return True
- identifier = TagIdentifier.from_string(argument_string)
+ identifier = TagIdentifier.from_string(name)
if identifier.group is None:
# Try to find accessible tags from a group matching the identifier's name.
- if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author):
+ if group_tags := self.accessible_tags_in_group(identifier.name, interaction.user):
await LinePaginator.paginate(
- group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS
+ group_tags, interaction, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS
)
return True
- embed = await self.get_tag_embed(ctx, identifier)
+ embed = await self.get_tag_embed(interaction.user, interaction.channel, identifier)
+ ephemeral = False
if embed is None:
- return False
-
- if embed is not COOLDOWN.obj:
+ description = f"**There are no tags matching the name {name!r}!**"
+ embed = Embed(description=description)
+ ephemeral = True
+ elif embed is COOLDOWN.obj:
+ description = f"Tag {name!r} is on cooldown."
+ embed = Embed(description=description)
+ ephemeral = True
+
+ await interaction.response.send_message(embed=embed, ephemeral=ephemeral)
+ if not ephemeral:
await wait_for_deletion(
- await ctx.send(embed=embed),
- (ctx.author.id,)
+ await interaction.original_response(),
+ (interaction.user.id,)
)
+
# A valid tag was found and was either sent, or is on cooldown
return True
+ @get_command.autocomplete("name")
+ async def name_autocomplete(
+ self,
+ interaction: Interaction,
+ current: str
+ ) -> list[app_commands.Choice[str]]:
+ """Autocompleter for `/tag get` command."""
+ names = [tag.name for tag in self.tags.keys()]
+ choices = [
+ app_commands.Choice(name=tag, value=tag)
+ for tag in names if current.lower() in tag
+ ]
+ return choices[:25] if len(choices) > 25 else choices
+
async def setup(bot: Bot) -> None:
"""Load the Tags cog."""
diff --git a/bot/pagination.py b/bot/pagination.py
index 0ef5808cc..c39ce211b 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -190,8 +190,8 @@ class LinePaginator(Paginator):
@classmethod
async def paginate(
cls,
- lines: t.List[str],
- ctx: Context,
+ lines: list[str],
+ ctx: Context | discord.Interaction,
embed: discord.Embed,
prefix: str = "",
suffix: str = "",
@@ -228,7 +228,10 @@ class LinePaginator(Paginator):
current_page = 0
if not restrict_to_user:
- restrict_to_user = ctx.author
+ if isinstance(ctx, discord.Interaction):
+ restrict_to_user = ctx.user
+ else:
+ restrict_to_user = ctx.author
if not lines:
if exception_on_empty_embed:
@@ -261,6 +264,8 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("There's less than two pages, so we won't paginate - sending single page on its own")
+ if isinstance(ctx, discord.Interaction):
+ return await ctx.response.send_message(embed=embed)
return await ctx.send(embed=embed)
else:
if footer_text:
@@ -274,7 +279,11 @@ class LinePaginator(Paginator):
log.trace(f"Setting embed url to '{url}'")
log.debug("Sending first page to channel...")
- message = await ctx.send(embed=embed)
+ if isinstance(ctx, discord.Interaction):
+ await ctx.response.send_message(embed=embed)
+ message = await ctx.original_response()
+ else:
+ message = await ctx.send(embed=embed)
log.debug("Adding emoji reactions to message...")
@@ -292,7 +301,10 @@ class LinePaginator(Paginator):
while True:
try:
- reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)
+ if isinstance(ctx, discord.Interaction):
+ reaction, user = await ctx.client.wait_for("reaction_add", timeout=timeout, check=check)
+ else:
+ reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)
log.trace(f"Got reaction: {reaction}")
except asyncio.TimeoutError:
log.debug("Timed out waiting for a reaction")
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index 27f2eac97..f6bdceaef 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -58,7 +58,7 @@ def reaction_check(
async def wait_for_deletion(
- message: discord.Message,
+ message: discord.Message | discord.InteractionMessage,
user_ids: Sequence[int],
deletion_emojis: Sequence[str] = (Emojis.trashcan,),
timeout: float = 60 * 5,
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 092de0556..0ba2fcf11 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -334,13 +334,13 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
self.ctx = MockContext()
self.tag = Tags(self.bot)
self.cog = error_handler.ErrorHandler(self.bot)
- self.bot.get_command.return_value = self.tag.get_command
+ self.bot.get_cog.return_value = self.tag
async def test_try_get_tag_get_command(self):
"""Should call `Bot.get_command` with `tags get` argument."""
- self.bot.get_command.reset_mock()
+ self.bot.get_cog.reset_mock()
await self.cog.try_get_tag(self.ctx)
- self.bot.get_command.assert_called_once_with("tags get")
+ self.bot.get_cog.assert_called_once_with("Tags")
async def test_try_get_tag_invoked_from_error_handler(self):
"""`self.ctx` should have `invoked_from_error_handler` `True`."""
@@ -350,14 +350,14 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_try_get_tag_no_permissions(self):
"""Test how to handle checks failing."""
- self.tag.get_command.can_run = AsyncMock(return_value=False)
+ self.bot.can_run = AsyncMock(return_value=False)
self.ctx.invoked_with = "foo"
self.assertIsNone(await self.cog.try_get_tag(self.ctx))
async def test_try_get_tag_command_error(self):
"""Should call `on_command_error` when `CommandError` raised."""
err = errors.CommandError()
- self.tag.get_command.can_run = AsyncMock(side_effect=err)
+ self.bot.can_run = AsyncMock(side_effect=err)
self.cog.on_command_error = AsyncMock()
self.assertIsNone(await self.cog.try_get_tag(self.ctx))
self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
@@ -365,7 +365,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_dont_call_suggestion_tag_sent(self):
"""Should never call command suggestion if tag is already sent."""
self.ctx.message = MagicMock(content="foo")
- self.ctx.invoke = AsyncMock(return_value=True)
+ self.tag.get_command_ctx = AsyncMock(return_value=True)
self.cog.send_command_suggestion = AsyncMock()
await self.cog.try_get_tag(self.ctx)
@@ -385,7 +385,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_call_suggestion(self):
"""Should call command suggestion if user is not a mod."""
self.ctx.invoked_with = "foo"
- self.ctx.invoke = AsyncMock(return_value=False)
+ self.tag.get_command_ctx = AsyncMock(return_value=False)
self.cog.send_command_suggestion = AsyncMock()
await self.cog.try_get_tag(self.ctx)
diff --git a/tests/helpers.py b/tests/helpers.py
index 4b980ac21..0d955b521 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -479,6 +479,25 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)
+class MockInteraction(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Interaction objects.
+
+ Instances of this class will follow the specifications of `discord.Interaction`
+ instances. For more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.me = kwargs.get('me', MockMember())
+ self.client = kwargs.get('client', MockBot())
+ self.guild = kwargs.get('guild', MockGuild())
+ self.user = kwargs.get('user', MockMember())
+ self.channel = kwargs.get('channel', MockTextChannel())
+ self.message = kwargs.get('message', MockMessage())
+ self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)
+
+
attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock())
@@ -530,6 +549,16 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
self.channel = kwargs.get('channel', MockTextChannel())
+class MockInteractionMessage(MockMessage):
+ """
+ A MagicMock subclass to mock InteractionMessage objects.
+
+ Instances of this class will follow the specifications of `discord.InteractionMessage` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ pass
+
+
emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'}
emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data)