aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/backend/error_handler.py50
-rw-r--r--bot/exts/info/tags.py36
-rw-r--r--tests/bot/exts/backend/test_error_handler.py10
3 files changed, 32 insertions, 64 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index 839d882de..e274e337a 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -1,8 +1,7 @@
import copy
import difflib
-import typing as t
-from discord import Embed, Interaction, utils
+from discord import Embed, Member
from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors
from pydis_core.site_api import ResponseCodeError
from sentry_sdk import push_scope
@@ -22,22 +21,6 @@ class ErrorHandler(Cog):
def __init__(self, bot: Bot):
self.bot = bot
- @staticmethod
- async def _can_run(ctx: Context) -> bool:
- """
- Add checks for the `get_command_ctx` function here.
-
- The command code style is copied from discord.ext.commands.Command.can_run itself.
- Append checks in the checks list.
- """
- checks = []
- predicates = checks
- if not predicates:
- # Since we have no checks, then we just return True.
- return True
-
- return await utils.async_all(predicate(ctx) for predicate in predicates)
-
def _get_error_embed(self, title: str, body: str) -> Embed:
"""Return an embed that contains the exception."""
return Embed(
@@ -176,7 +159,7 @@ class ErrorHandler(Cog):
return True
return False
- async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None:
+ async def try_get_tag(self, ctx: Context) -> None:
"""
Attempt to display a tag by interpreting the command name as a tag name.
@@ -189,25 +172,28 @@ class ErrorHandler(Cog):
log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.")
return
tags_get_command = tags_cog.get_command_ctx
- can_run = can_run if can_run else self._can_run
- ctx.invoked_from_error_handler = True
+ maybe_tag_name = ctx.invoked_with
+ if not maybe_tag_name or not isinstance(ctx.author, Member):
+ return
- log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
+ ctx.invoked_from_error_handler = True
try:
- if not await can_run(ctx):
- log.debug(log_msg)
+ if not await self.bot.can_run(ctx):
+ log.debug("Cancelling attempt to fall back to a tag due to failed checks.")
return
- except errors.CommandError as tag_error:
- log.debug(log_msg)
- await self.on_command_error(ctx, tag_error)
- return
- if await tags_get_command(ctx, ctx.message.content):
- return
+ if await tags_get_command(ctx, maybe_tag_name):
+ return
- if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
- await self.send_command_suggestion(ctx, ctx.invoked_with)
+ if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
+ await self.send_command_suggestion(ctx, maybe_tag_name)
+ except Exception as err:
+ log.debug("Error while attempting to invoke tag fallback.")
+ if isinstance(err, errors.CommandError):
+ await self.on_command_error(ctx, err)
+ else:
+ await self.on_command_error(ctx, errors.CommandInvokeError(err))
async def try_run_eval(self, ctx: Context) -> bool:
"""
diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py
index 60f730586..0c244ff37 100644
--- a/bot/exts/info/tags.py
+++ b/bot/exts/info/tags.py
@@ -8,8 +8,8 @@ from typing import Literal, NamedTuple, Optional, Union
import discord
import frontmatter
-from discord import Embed, Interaction, Member, app_commands
-from discord.ext.commands import Cog
+from discord import Embed, Interaction, Member, User, app_commands
+from discord.ext.commands import Cog, Context
from bot import constants
from bot.bot import Bot
@@ -27,8 +27,6 @@ TEST_CHANNELS = (
REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE)
FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags <tagname>."
-GUILD_ID = constants.Guild.id
-
class COOLDOWN(enum.Enum):
"""Sentinel value to signal that a tag is on cooldown."""
@@ -93,7 +91,7 @@ class Tag:
embed.description = self.content
return embed
- def accessible_by(self, member: discord.Member) -> bool:
+ def accessible_by(self, member: Member | User) -> bool:
"""Check whether `member` can access the tag."""
return bool(
not self._restricted_to
@@ -141,8 +139,6 @@ class Tags(Cog):
self.tags: dict[TagIdentifier, Tag] = {}
self.initialize_tags()
- tag_group = app_commands.Group(name="tag", description="...")
-
def initialize_tags(self) -> None:
"""Load all tags from resources into `self.tags`."""
base_path = Path("bot", "resources", "tags")
@@ -188,8 +184,8 @@ class Tags(Cog):
async def get_tag_embed(
self,
- author: discord.Member,
- channel: discord.TextChannel | discord.Thread,
+ author: Member | User,
+ channel: discord.abc.Messageable,
tag_identifier: TagIdentifier,
) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]:
"""
@@ -244,7 +240,7 @@ class Tags(Cog):
description=suggested_tags_text
)
- def accessible_tags(self, user: Member) -> list[str]:
+ def accessible_tags(self, user: Member | User) -> list[str]:
"""Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted."""
def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str:
group, name = tag_item[0]
@@ -278,7 +274,7 @@ class Tags(Cog):
return result_lines
- def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]:
+ def accessible_tags_in_group(self, group: str, user: Member | User) -> list[str]:
"""Return a formatted list of tags in `group`, that are accessible by `user`."""
return sorted(
f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
@@ -288,7 +284,7 @@ class Tags(Cog):
async def get_command_ctx(
self,
- ctx: discord.Context,
+ ctx: Context,
name: str
) -> bool:
"""Made specifically for `error_handler.py`, See `get_command` for more info."""
@@ -315,7 +311,7 @@ class Tags(Cog):
# A valid tag was found and was either sent, or is on cooldown
return True
- @tag_group.command(name="get")
+ @app_commands.command(name="tag")
async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool:
"""
If a single argument matching a group name is given, list all accessible tags from that group
@@ -382,20 +378,6 @@ class Tags(Cog):
]
return choices[:25] if len(choices) > 25 else choices
- @tag_group.command(name="list")
- async def list_command(self, interaction: Interaction) -> bool:
- """Lists all accessible tags."""
- if self.tags:
- await LinePaginator.paginate(
- self.accessible_tags(interaction.user),
- interaction,
- Embed(title="Available tags"),
- **self.PAGINATOR_DEFAULTS,
- )
- else:
- await interaction.response.send_message(embed=Embed(description="**There are no tags!**"))
- return True
-
async def setup(bot: Bot) -> None:
"""Load the Tags cog."""
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 14e7a4125..533eaeda6 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -350,16 +350,16 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_try_get_tag_no_permissions(self):
"""Test how to handle checks failing."""
- self.tag.get_command.can_run = AsyncMock(return_value=False)
+ self.bot.can_run = AsyncMock(return_value=False)
self.ctx.invoked_with = "foo"
- self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False)))
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
async def test_try_get_tag_command_error(self):
"""Should call `on_command_error` when `CommandError` raised."""
err = errors.CommandError()
- self.tag.get_command.can_run = AsyncMock(side_effect=err)
+ self.bot.can_run = AsyncMock(side_effect=err)
self.cog.on_command_error = AsyncMock()
- self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err)))
+ self.assertIsNone(await self.cog.try_get_tag(self.ctx))
self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
async def test_dont_call_suggestion_tag_sent(self):
@@ -385,7 +385,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_call_suggestion(self):
"""Should call command suggestion if user is not a mod."""
self.ctx.invoked_with = "foo"
- self.ctx.invoke = AsyncMock(return_value=False)
+ self.tag.get_command_ctx = AsyncMock(return_value=False)
self.cog.send_command_suggestion = AsyncMock()
await self.cog.try_get_tag(self.ctx)