From ff1d3a8269a17956e9617c2ac69646fc1f29a00e Mon Sep 17 00:00:00 2001 From: hedy Date: Sat, 6 Apr 2024 15:13:34 +0800 Subject: Fix showing `!source` on tags when tags cog is reloaded Previously the source command checks for the source object's class after getting it from the SourceType converter. When the tags cog is reloaded, TagIdentifier is redefined. We were checking whether the old TagIdentifier class (the `type()` of an instance of the old TagIdentifier) equals the newly defined TagIdentifier class. This returns false because the TagIdentifier class is redefined when the tags cog is reloaded, which caused tags, correctly identified by Source converter, to be unable to be identified in the source cog. The fix takes advantage of the fact that the source converter could correctly identify tags objects even after the tags cog reloads, and use enum comparisons rather than `isinstance`/`type()` to obtain the source type of whatever the source converter returns. Since we're no longer using the source converter as an actual converter, the function is moved to the source cog instead, and it still works fine. --- bot/converters.py | 36 +------------------------- bot/exts/info/source.py | 67 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/bot/converters.py b/bot/converters.py index 34a764567..c04158d4d 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -10,7 +10,7 @@ import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Context, Converter, IDConverter, MemberConverter, UserConverter -from discord.utils import escape_markdown, snowflake_time +from discord.utils import snowflake_time from pydis_core.site_api import ResponseCodeError from pydis_core.utils import unqualify from pydis_core.utils.regex import DISCORD_INVITE @@ -19,7 +19,6 @@ from bot import exts, instance as bot_instance from bot.constants import URLs from bot.errors import InvalidInfractionError from bot.exts.info.doc import _inventory_parser -from bot.exts.info.tags import TagIdentifier from bot.log import get_logger from bot.utils import time @@ -220,39 +219,6 @@ class Snowflake(IDConverter): return snowflake -class SourceConverter(Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - @staticmethod - async def convert(ctx: Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower() == "help": - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - else: - identifier = TagIdentifier.from_string(argument.lower()) - if identifier in tags_cog.tags: - return identifier - escaped_arg = escape_markdown(argument) - - raise BadArgument( - f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." - ) - - class DurationDelta(Converter): """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 1c3387e53..9dd5c39bf 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -1,15 +1,25 @@ +import enum import inspect from pathlib import Path from discord import Embed from discord.ext import commands +from discord.utils import escape_markdown from bot.bot import Bot from bot.constants import URLs -from bot.converters import SourceConverter from bot.exts.info.tags import TagIdentifier -SourceType = commands.HelpCommand | commands.Command | commands.Cog | TagIdentifier | commands.ExtensionNotLoaded +SourceObject = commands.HelpCommand | commands.Command | commands.Cog | TagIdentifier | commands.ExtensionNotLoaded + +class SourceType(enum.StrEnum): + """The types of source objects recognized by the source command.""" + + help_command = enum.auto() + command = enum.auto() + cog = enum.auto() + tag = enum.auto() + extension_not_loaded = enum.auto() class BotSource(commands.Cog): @@ -23,7 +33,7 @@ class BotSource(commands.Cog): self, ctx: commands.Context, *, - source_item: SourceConverter = None, + source_item: str | None = None, ) -> None: """Display information and a GitHub link to the source code of a command, tag, or cog.""" if not source_item: @@ -33,20 +43,51 @@ class BotSource(commands.Cog): await ctx.send(embed=embed) return - embed = await self.build_embed(source_item) + obj, source_type = await self.get_source_object(ctx, source_item) + embed = await self.build_embed(obj, source_type) await ctx.send(embed=embed) - def get_source_link(self, source_item: SourceType) -> tuple[str, str, int | None]: + @staticmethod + async def get_source_object(ctx: commands.Context, argument: str) -> tuple[SourceObject, SourceType]: + """Convert argument into the source object and source type.""" + if argument.lower() == "help": + return ctx.bot.help_command, SourceType.help_command + + cog = ctx.bot.get_cog(argument) + if cog: + return cog, SourceType.cog + + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd, SourceType.command + + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + else: + identifier = TagIdentifier.from_string(argument.lower()) + if identifier in tags_cog.tags: + return identifier, SourceType.tag + + escaped_arg = escape_markdown(argument) + + raise commands.BadArgument( + f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." + ) + + def get_source_link(self, source_item: SourceObject, source_type: SourceType) -> tuple[str, str, int | None]: """ Build GitHub link of source item, return this link, file location and first line number. Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). """ - if isinstance(source_item, commands.Command): + if source_type == SourceType.command: source_item = inspect.unwrap(source_item.callback) src = source_item.__code__ filename = src.co_filename - elif isinstance(source_item, TagIdentifier): + elif source_type == SourceType.tag: tags_cog = self.bot.get_cog("Tags") filename = tags_cog.tags[source_item].file_path else: @@ -56,7 +97,7 @@ class BotSource(commands.Cog): except TypeError: raise commands.BadArgument("Cannot get source for a dynamically-created object.") - if not isinstance(source_item, TagIdentifier): + if source_type != SourceType.tag: try: lines, first_line_no = inspect.getsourcelines(src) except OSError: @@ -77,17 +118,17 @@ class BotSource(commands.Cog): return url, file_location, first_line_no or None - async def build_embed(self, source_object: SourceType) -> Embed | None: + async def build_embed(self, source_object: SourceObject, source_type: SourceType) -> Embed | None: """Build embed based on source object.""" - url, location, first_line = self.get_source_link(source_object) + url, location, first_line = self.get_source_link(source_object, source_type) - if isinstance(source_object, commands.HelpCommand): + if source_type == SourceType.help_command: title = "Help Command" description = source_object.__doc__.splitlines()[1] - elif isinstance(source_object, commands.Command): + elif source_type == SourceType.command: description = source_object.short_doc title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, TagIdentifier): + elif source_type == SourceType.tag: title = f"Tag: {source_object}" description = "" else: -- cgit v1.2.3