diff options
author | 2024-04-06 15:13:34 +0800 | |
---|---|---|
committer | 2024-04-08 08:46:13 +0800 | |
commit | ff1d3a8269a17956e9617c2ac69646fc1f29a00e (patch) | |
tree | 17d4ead31afa14efad67b111527c35d27383dc64 | |
parent | Bump sentry-sdk from 1.44.0 to 1.44.1 (#3002) (diff) |
Fix showing `!source` on tags when tags cog is reloaded
Previously the source command checks for the source object's class after
getting it from the SourceType converter. When the tags cog is reloaded,
TagIdentifier is redefined. We were checking whether the old
TagIdentifier class (the `type()` of an instance of the old
TagIdentifier) equals the newly defined TagIdentifier class. This
returns false because the TagIdentifier class is redefined when the tags
cog is reloaded, which caused tags, correctly identified by Source
converter, to be unable to be identified in the source cog.
The fix takes advantage of the fact that the source converter could
correctly identify tags objects even after the tags cog reloads, and use
enum comparisons rather than `isinstance`/`type()` to obtain the source
type of whatever the source converter returns.
Since we're no longer using the source converter as an actual converter,
the function is moved to the source cog instead, and it still works
fine.
-rw-r--r-- | bot/converters.py | 36 | ||||
-rw-r--r-- | bot/exts/info/source.py | 67 |
2 files changed, 55 insertions, 48 deletions
diff --git a/bot/converters.py b/bot/converters.py index 34a764567..c04158d4d 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -10,7 +10,7 @@ import discord from aiohttp import ClientConnectorError from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Context, Converter, IDConverter, MemberConverter, UserConverter -from discord.utils import escape_markdown, snowflake_time +from discord.utils import snowflake_time from pydis_core.site_api import ResponseCodeError from pydis_core.utils import unqualify from pydis_core.utils.regex import DISCORD_INVITE @@ -19,7 +19,6 @@ from bot import exts, instance as bot_instance from bot.constants import URLs from bot.errors import InvalidInfractionError from bot.exts.info.doc import _inventory_parser -from bot.exts.info.tags import TagIdentifier from bot.log import get_logger from bot.utils import time @@ -220,39 +219,6 @@ class Snowflake(IDConverter): return snowflake -class SourceConverter(Converter): - """Convert an argument into a help command, tag, command, or cog.""" - - @staticmethod - async def convert(ctx: Context, argument: str) -> SourceType: - """Convert argument into source object.""" - if argument.lower() == "help": - return ctx.bot.help_command - - cog = ctx.bot.get_cog(argument) - if cog: - return cog - - cmd = ctx.bot.get_command(argument) - if cmd: - return cmd - - tags_cog = ctx.bot.get_cog("Tags") - show_tag = True - - if not tags_cog: - show_tag = False - else: - identifier = TagIdentifier.from_string(argument.lower()) - if identifier in tags_cog.tags: - return identifier - escaped_arg = escape_markdown(argument) - - raise BadArgument( - f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." - ) - - class DurationDelta(Converter): """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 1c3387e53..9dd5c39bf 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -1,15 +1,25 @@ +import enum import inspect from pathlib import Path from discord import Embed from discord.ext import commands +from discord.utils import escape_markdown from bot.bot import Bot from bot.constants import URLs -from bot.converters import SourceConverter from bot.exts.info.tags import TagIdentifier -SourceType = commands.HelpCommand | commands.Command | commands.Cog | TagIdentifier | commands.ExtensionNotLoaded +SourceObject = commands.HelpCommand | commands.Command | commands.Cog | TagIdentifier | commands.ExtensionNotLoaded + +class SourceType(enum.StrEnum): + """The types of source objects recognized by the source command.""" + + help_command = enum.auto() + command = enum.auto() + cog = enum.auto() + tag = enum.auto() + extension_not_loaded = enum.auto() class BotSource(commands.Cog): @@ -23,7 +33,7 @@ class BotSource(commands.Cog): self, ctx: commands.Context, *, - source_item: SourceConverter = None, + source_item: str | None = None, ) -> None: """Display information and a GitHub link to the source code of a command, tag, or cog.""" if not source_item: @@ -33,20 +43,51 @@ class BotSource(commands.Cog): await ctx.send(embed=embed) return - embed = await self.build_embed(source_item) + obj, source_type = await self.get_source_object(ctx, source_item) + embed = await self.build_embed(obj, source_type) await ctx.send(embed=embed) - def get_source_link(self, source_item: SourceType) -> tuple[str, str, int | None]: + @staticmethod + async def get_source_object(ctx: commands.Context, argument: str) -> tuple[SourceObject, SourceType]: + """Convert argument into the source object and source type.""" + if argument.lower() == "help": + return ctx.bot.help_command, SourceType.help_command + + cog = ctx.bot.get_cog(argument) + if cog: + return cog, SourceType.cog + + cmd = ctx.bot.get_command(argument) + if cmd: + return cmd, SourceType.command + + tags_cog = ctx.bot.get_cog("Tags") + show_tag = True + + if not tags_cog: + show_tag = False + else: + identifier = TagIdentifier.from_string(argument.lower()) + if identifier in tags_cog.tags: + return identifier, SourceType.tag + + escaped_arg = escape_markdown(argument) + + raise commands.BadArgument( + f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog." + ) + + def get_source_link(self, source_item: SourceObject, source_type: SourceType) -> tuple[str, str, int | None]: """ Build GitHub link of source item, return this link, file location and first line number. Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval). """ - if isinstance(source_item, commands.Command): + if source_type == SourceType.command: source_item = inspect.unwrap(source_item.callback) src = source_item.__code__ filename = src.co_filename - elif isinstance(source_item, TagIdentifier): + elif source_type == SourceType.tag: tags_cog = self.bot.get_cog("Tags") filename = tags_cog.tags[source_item].file_path else: @@ -56,7 +97,7 @@ class BotSource(commands.Cog): except TypeError: raise commands.BadArgument("Cannot get source for a dynamically-created object.") - if not isinstance(source_item, TagIdentifier): + if source_type != SourceType.tag: try: lines, first_line_no = inspect.getsourcelines(src) except OSError: @@ -77,17 +118,17 @@ class BotSource(commands.Cog): return url, file_location, first_line_no or None - async def build_embed(self, source_object: SourceType) -> Embed | None: + async def build_embed(self, source_object: SourceObject, source_type: SourceType) -> Embed | None: """Build embed based on source object.""" - url, location, first_line = self.get_source_link(source_object) + url, location, first_line = self.get_source_link(source_object, source_type) - if isinstance(source_object, commands.HelpCommand): + if source_type == SourceType.help_command: title = "Help Command" description = source_object.__doc__.splitlines()[1] - elif isinstance(source_object, commands.Command): + elif source_type == SourceType.command: description = source_object.short_doc title = f"Command: {source_object.qualified_name}" - elif isinstance(source_object, TagIdentifier): + elif source_type == SourceType.tag: title = f"Tag: {source_object}" description = "" else: |