aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar hedy <[email protected]>2024-04-06 15:13:34 +0800
committerGravatar hedy <[email protected]>2024-04-08 08:46:13 +0800
commitff1d3a8269a17956e9617c2ac69646fc1f29a00e (patch)
tree17d4ead31afa14efad67b111527c35d27383dc64
parentBump sentry-sdk from 1.44.0 to 1.44.1 (#3002) (diff)
Fix showing `!source` on tags when tags cog is reloaded
Previously the source command checks for the source object's class after getting it from the SourceType converter. When the tags cog is reloaded, TagIdentifier is redefined. We were checking whether the old TagIdentifier class (the `type()` of an instance of the old TagIdentifier) equals the newly defined TagIdentifier class. This returns false because the TagIdentifier class is redefined when the tags cog is reloaded, which caused tags, correctly identified by Source converter, to be unable to be identified in the source cog. The fix takes advantage of the fact that the source converter could correctly identify tags objects even after the tags cog reloads, and use enum comparisons rather than `isinstance`/`type()` to obtain the source type of whatever the source converter returns. Since we're no longer using the source converter as an actual converter, the function is moved to the source cog instead, and it still works fine.
-rw-r--r--bot/converters.py36
-rw-r--r--bot/exts/info/source.py67
2 files changed, 55 insertions, 48 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 34a764567..c04158d4d 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -10,7 +10,7 @@ import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
from discord.ext.commands import BadArgument, Context, Converter, IDConverter, MemberConverter, UserConverter
-from discord.utils import escape_markdown, snowflake_time
+from discord.utils import snowflake_time
from pydis_core.site_api import ResponseCodeError
from pydis_core.utils import unqualify
from pydis_core.utils.regex import DISCORD_INVITE
@@ -19,7 +19,6 @@ from bot import exts, instance as bot_instance
from bot.constants import URLs
from bot.errors import InvalidInfractionError
from bot.exts.info.doc import _inventory_parser
-from bot.exts.info.tags import TagIdentifier
from bot.log import get_logger
from bot.utils import time
@@ -220,39 +219,6 @@ class Snowflake(IDConverter):
return snowflake
-class SourceConverter(Converter):
- """Convert an argument into a help command, tag, command, or cog."""
-
- @staticmethod
- async def convert(ctx: Context, argument: str) -> SourceType:
- """Convert argument into source object."""
- if argument.lower() == "help":
- return ctx.bot.help_command
-
- cog = ctx.bot.get_cog(argument)
- if cog:
- return cog
-
- cmd = ctx.bot.get_command(argument)
- if cmd:
- return cmd
-
- tags_cog = ctx.bot.get_cog("Tags")
- show_tag = True
-
- if not tags_cog:
- show_tag = False
- else:
- identifier = TagIdentifier.from_string(argument.lower())
- if identifier in tags_cog.tags:
- return identifier
- escaped_arg = escape_markdown(argument)
-
- raise BadArgument(
- f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
- )
-
-
class DurationDelta(Converter):
"""Convert duration strings into dateutil.relativedelta.relativedelta objects."""
diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py
index 1c3387e53..9dd5c39bf 100644
--- a/bot/exts/info/source.py
+++ b/bot/exts/info/source.py
@@ -1,15 +1,25 @@
+import enum
import inspect
from pathlib import Path
from discord import Embed
from discord.ext import commands
+from discord.utils import escape_markdown
from bot.bot import Bot
from bot.constants import URLs
-from bot.converters import SourceConverter
from bot.exts.info.tags import TagIdentifier
-SourceType = commands.HelpCommand | commands.Command | commands.Cog | TagIdentifier | commands.ExtensionNotLoaded
+SourceObject = commands.HelpCommand | commands.Command | commands.Cog | TagIdentifier | commands.ExtensionNotLoaded
+
+class SourceType(enum.StrEnum):
+ """The types of source objects recognized by the source command."""
+
+ help_command = enum.auto()
+ command = enum.auto()
+ cog = enum.auto()
+ tag = enum.auto()
+ extension_not_loaded = enum.auto()
class BotSource(commands.Cog):
@@ -23,7 +33,7 @@ class BotSource(commands.Cog):
self,
ctx: commands.Context,
*,
- source_item: SourceConverter = None,
+ source_item: str | None = None,
) -> None:
"""Display information and a GitHub link to the source code of a command, tag, or cog."""
if not source_item:
@@ -33,20 +43,51 @@ class BotSource(commands.Cog):
await ctx.send(embed=embed)
return
- embed = await self.build_embed(source_item)
+ obj, source_type = await self.get_source_object(ctx, source_item)
+ embed = await self.build_embed(obj, source_type)
await ctx.send(embed=embed)
- def get_source_link(self, source_item: SourceType) -> tuple[str, str, int | None]:
+ @staticmethod
+ async def get_source_object(ctx: commands.Context, argument: str) -> tuple[SourceObject, SourceType]:
+ """Convert argument into the source object and source type."""
+ if argument.lower() == "help":
+ return ctx.bot.help_command, SourceType.help_command
+
+ cog = ctx.bot.get_cog(argument)
+ if cog:
+ return cog, SourceType.cog
+
+ cmd = ctx.bot.get_command(argument)
+ if cmd:
+ return cmd, SourceType.command
+
+ tags_cog = ctx.bot.get_cog("Tags")
+ show_tag = True
+
+ if not tags_cog:
+ show_tag = False
+ else:
+ identifier = TagIdentifier.from_string(argument.lower())
+ if identifier in tags_cog.tags:
+ return identifier, SourceType.tag
+
+ escaped_arg = escape_markdown(argument)
+
+ raise commands.BadArgument(
+ f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
+ )
+
+ def get_source_link(self, source_item: SourceObject, source_type: SourceType) -> tuple[str, str, int | None]:
"""
Build GitHub link of source item, return this link, file location and first line number.
Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval).
"""
- if isinstance(source_item, commands.Command):
+ if source_type == SourceType.command:
source_item = inspect.unwrap(source_item.callback)
src = source_item.__code__
filename = src.co_filename
- elif isinstance(source_item, TagIdentifier):
+ elif source_type == SourceType.tag:
tags_cog = self.bot.get_cog("Tags")
filename = tags_cog.tags[source_item].file_path
else:
@@ -56,7 +97,7 @@ class BotSource(commands.Cog):
except TypeError:
raise commands.BadArgument("Cannot get source for a dynamically-created object.")
- if not isinstance(source_item, TagIdentifier):
+ if source_type != SourceType.tag:
try:
lines, first_line_no = inspect.getsourcelines(src)
except OSError:
@@ -77,17 +118,17 @@ class BotSource(commands.Cog):
return url, file_location, first_line_no or None
- async def build_embed(self, source_object: SourceType) -> Embed | None:
+ async def build_embed(self, source_object: SourceObject, source_type: SourceType) -> Embed | None:
"""Build embed based on source object."""
- url, location, first_line = self.get_source_link(source_object)
+ url, location, first_line = self.get_source_link(source_object, source_type)
- if isinstance(source_object, commands.HelpCommand):
+ if source_type == SourceType.help_command:
title = "Help Command"
description = source_object.__doc__.splitlines()[1]
- elif isinstance(source_object, commands.Command):
+ elif source_type == SourceType.command:
description = source_object.short_doc
title = f"Command: {source_object.qualified_name}"
- elif isinstance(source_object, TagIdentifier):
+ elif source_type == SourceType.tag:
title = f"Tag: {source_object}"
description = ""
else: