diff options
author | 2023-02-26 17:44:47 +0530 | |
---|---|---|
committer | 2023-02-26 17:44:47 +0530 | |
commit | 03614c313341497e61c45bbb2a364b969d2bb163 (patch) | |
tree | b17bea05899f2ba1206b5e31603c5142206f40aa | |
parent | Upadte docstring for `ErrorHandler()._can_run` (diff) |
Implement reviews
+ used both `discord.User` and `discord.Member` in typehinting as `InteractionResponse.user` returns `discord.User` object
+ removed `ErrorHandler()._can_run`
+ edited `try_get_tag` to use `bot.can_run`
+ removed `/tag list`
+ change `/tag get <name>` to `/tag <name>`
+ remove redundant `GUILD_ID` in `tags.py`
+ using `discord.abc.Messageable` because `ctx.channel` returns that instead of `Channel` Object
-rw-r--r-- | bot/exts/backend/error_handler.py | 50 | ||||
-rw-r--r-- | bot/exts/info/tags.py | 36 | ||||
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 10 |
3 files changed, 32 insertions, 64 deletions
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 839d882de..e274e337a 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,8 +1,7 @@ import copy import difflib -import typing as t -from discord import Embed, Interaction, utils +from discord import Embed, Member from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -22,22 +21,6 @@ class ErrorHandler(Cog): def __init__(self, bot: Bot): self.bot = bot - @staticmethod - async def _can_run(ctx: Context) -> bool: - """ - Add checks for the `get_command_ctx` function here. - - The command code style is copied from discord.ext.commands.Command.can_run itself. - Append checks in the checks list. - """ - checks = [] - predicates = checks - if not predicates: - # Since we have no checks, then we just return True. - return True - - return await utils.async_all(predicate(ctx) for predicate in predicates) - def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" return Embed( @@ -176,7 +159,7 @@ class ErrorHandler(Cog): return True return False - async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None: + async def try_get_tag(self, ctx: Context) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -189,25 +172,28 @@ class ErrorHandler(Cog): log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.") return tags_get_command = tags_cog.get_command_ctx - can_run = can_run if can_run else self._can_run - ctx.invoked_from_error_handler = True + maybe_tag_name = ctx.invoked_with + if not maybe_tag_name or not isinstance(ctx.author, Member): + return - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + ctx.invoked_from_error_handler = True try: - if not await can_run(ctx): - log.debug(log_msg) + if not await self.bot.can_run(ctx): + log.debug("Cancelling attempt to fall back to a tag due to failed checks.") return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - if await tags_get_command(ctx, ctx.message.content): - return + if await tags_get_command(ctx, maybe_tag_name): + return - if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): - await self.send_command_suggestion(ctx, ctx.invoked_with) + if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): + await self.send_command_suggestion(ctx, maybe_tag_name) + except Exception as err: + log.debug("Error while attempting to invoke tag fallback.") + if isinstance(err, errors.CommandError): + await self.on_command_error(ctx, err) + else: + await self.on_command_error(ctx, errors.CommandInvokeError(err)) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 60f730586..0c244ff37 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -8,8 +8,8 @@ from typing import Literal, NamedTuple, Optional, Union import discord import frontmatter -from discord import Embed, Interaction, Member, app_commands -from discord.ext.commands import Cog +from discord import Embed, Interaction, Member, User, app_commands +from discord.ext.commands import Cog, Context from bot import constants from bot.bot import Bot @@ -27,8 +27,6 @@ TEST_CHANNELS = ( REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags <tagname>." -GUILD_ID = constants.Guild.id - class COOLDOWN(enum.Enum): """Sentinel value to signal that a tag is on cooldown.""" @@ -93,7 +91,7 @@ class Tag: embed.description = self.content return embed - def accessible_by(self, member: discord.Member) -> bool: + def accessible_by(self, member: Member | User) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to @@ -141,8 +139,6 @@ class Tags(Cog): self.tags: dict[TagIdentifier, Tag] = {} self.initialize_tags() - tag_group = app_commands.Group(name="tag", description="...") - def initialize_tags(self) -> None: """Load all tags from resources into `self.tags`.""" base_path = Path("bot", "resources", "tags") @@ -188,8 +184,8 @@ class Tags(Cog): async def get_tag_embed( self, - author: discord.Member, - channel: discord.TextChannel | discord.Thread, + author: Member | User, + channel: discord.abc.Messageable, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ @@ -244,7 +240,7 @@ class Tags(Cog): description=suggested_tags_text ) - def accessible_tags(self, user: Member) -> list[str]: + def accessible_tags(self, user: Member | User) -> list[str]: """Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted.""" def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: group, name = tag_item[0] @@ -278,7 +274,7 @@ class Tags(Cog): return result_lines - def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: + def accessible_tags_in_group(self, group: str, user: Member | User) -> list[str]: """Return a formatted list of tags in `group`, that are accessible by `user`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" @@ -288,7 +284,7 @@ class Tags(Cog): async def get_command_ctx( self, - ctx: discord.Context, + ctx: Context, name: str ) -> bool: """Made specifically for `error_handler.py`, See `get_command` for more info.""" @@ -315,7 +311,7 @@ class Tags(Cog): # A valid tag was found and was either sent, or is on cooldown return True - @tag_group.command(name="get") + @app_commands.command(name="tag") async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool: """ If a single argument matching a group name is given, list all accessible tags from that group @@ -382,20 +378,6 @@ class Tags(Cog): ] return choices[:25] if len(choices) > 25 else choices - @tag_group.command(name="list") - async def list_command(self, interaction: Interaction) -> bool: - """Lists all accessible tags.""" - if self.tags: - await LinePaginator.paginate( - self.accessible_tags(interaction.user), - interaction, - Embed(title="Available tags"), - **self.PAGINATOR_DEFAULTS, - ) - else: - await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) - return True - async def setup(bot: Bot) -> None: """Load the Tags cog.""" diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 14e7a4125..533eaeda6 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -350,16 +350,16 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" - self.tag.get_command.can_run = AsyncMock(return_value=False) + self.bot.can_run = AsyncMock(return_value=False) self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False))) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() - self.tag.get_command.can_run = AsyncMock(side_effect=err) + self.bot.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err))) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) async def test_dont_call_suggestion_tag_sent(self): @@ -385,7 +385,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) + self.tag.get_command_ctx = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() await self.cog.try_get_tag(self.ctx) |