diff options
-rw-r--r-- | bot/bot.py | 8 | ||||
-rw-r--r-- | bot/exts/avatar_modification/avatar_modify.py | 3 | ||||
-rw-r--r-- | bot/exts/core/extensions.py | 22 | ||||
-rw-r--r-- | bot/exts/core/internal_eval/_internal_eval.py | 3 | ||||
-rw-r--r-- | bot/exts/events/advent_of_code/_cog.py | 3 | ||||
-rw-r--r-- | bot/exts/fun/game.py | 3 | ||||
-rw-r--r-- | bot/exts/fun/minesweeper.py | 8 | ||||
-rw-r--r-- | bot/exts/fun/movie.py | 4 | ||||
-rw-r--r-- | bot/exts/fun/snakes/_snakes_cog.py | 3 | ||||
-rw-r--r-- | bot/exts/fun/space.py | 4 | ||||
-rw-r--r-- | bot/exts/utilities/colour.py | 3 | ||||
-rw-r--r-- | bot/exts/utilities/emoji.py | 8 | ||||
-rw-r--r-- | bot/exts/utilities/epoch.py | 8 | ||||
-rw-r--r-- | bot/exts/utilities/githubinfo.py | 3 | ||||
-rw-r--r-- | bot/exts/utilities/reddit.py | 3 | ||||
-rw-r--r-- | bot/exts/utilities/twemoji.py | 3 | ||||
-rw-r--r-- | bot/utils/extensions.py | 45 |
17 files changed, 46 insertions, 88 deletions
@@ -60,3 +60,11 @@ class Bot(BotBase): # This is not awaited to avoid a deadlock with any cogs that have # wait_until_guild_available in their cog_load method. scheduling.create_task(self.load_extensions(exts)) + + async def invoke_help_command(self, ctx: commands.Context) -> None: + """Invoke the help command or default help command if help extensions is not loaded.""" + if "bot.exts.core.help" in ctx.bot.extensions: + help_command = ctx.bot.get_command("help") + await ctx.invoke(help_command, ctx.command.qualified_name) + return + await ctx.send_help(ctx.command) diff --git a/bot/exts/avatar_modification/avatar_modify.py b/bot/exts/avatar_modification/avatar_modify.py index 3ee70cfd..337f510c 100644 --- a/bot/exts/avatar_modification/avatar_modify.py +++ b/bot/exts/avatar_modification/avatar_modify.py @@ -14,7 +14,6 @@ from discord.ext import commands from bot.bot import Bot from bot.constants import Colours, Emojis from bot.exts.avatar_modification._effects import PfpEffects -from bot.utils.extensions import invoke_help_command from bot.utils.halloween import spookifications log = logging.getLogger(__name__) @@ -89,7 +88,7 @@ class AvatarModify(commands.Cog): async def avatar_modify(self, ctx: commands.Context) -> None: """Groups all of the pfp modifying commands to allow a single concurrency limit.""" if not ctx.invoked_subcommand: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @avatar_modify.command(name="8bitify", root_aliases=("8bitify",)) async def eightbit_command(self, ctx: commands.Context) -> None: diff --git a/bot/exts/core/extensions.py b/bot/exts/core/extensions.py index d809d2b9..586222c8 100644 --- a/bot/exts/core/extensions.py +++ b/bot/exts/core/extensions.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from enum import Enum from typing import Optional +from botcore.utils._extensions import unqualify from discord import Colour, Embed from discord.ext import commands from discord.ext.commands import Context, group @@ -12,7 +13,6 @@ from bot import exts from bot.bot import Bot from bot.constants import Client, Emojis, MODERATION_ROLES, Roles from bot.utils.checks import with_role_check -from bot.utils.extensions import EXTENSIONS, invoke_help_command, unqualify from bot.utils.pagination import LinePaginator log = logging.getLogger(__name__) @@ -46,13 +46,13 @@ class Extension(commands.Converter): argument = argument.lower() - if argument in EXTENSIONS: + if argument in ctx.bot.all_extensions: return argument - elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: + elif (qualified_arg := f"{exts.__name__}.{argument}") in ctx.bot.all_extensions: return qualified_arg matches = [] - for ext in EXTENSIONS: + for ext in ctx.bot.all_extensions: if argument == unqualify(ext): matches.append(ext) @@ -78,7 +78,7 @@ class Extensions(commands.Cog): @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) async def extensions_group(self, ctx: Context) -> None: """Load, unload, reload, and list loaded extensions.""" - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @extensions_group.command(name="load", aliases=("l",)) async def load_command(self, ctx: Context, *extensions: Extension) -> None: @@ -88,11 +88,11 @@ class Extensions(commands.Cog): If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. """ # noqa: W605 if not extensions: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return if "*" in extensions or "**" in extensions: - extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + extensions = set(self.bot.all_extensions) - set(self.bot.extensions.keys()) msg = self.batch_manage(Action.LOAD, *extensions) await ctx.send(msg) @@ -105,7 +105,7 @@ class Extensions(commands.Cog): If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. """ # noqa: W605 if not extensions: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) @@ -131,11 +131,11 @@ class Extensions(commands.Cog): If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. """ # noqa: W605 if not extensions: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return if "**" in extensions: - extensions = EXTENSIONS + extensions = self.bot.all_extensions elif "*" in extensions: extensions = set(self.bot.extensions.keys()) | set(extensions) extensions.remove("*") @@ -175,7 +175,7 @@ class Extensions(commands.Cog): """Return a mapping of extension names and statuses to their categories.""" categories = {} - for ext in EXTENSIONS: + for ext in self.bot.all_extensions: if ext in self.bot.extensions: status = Emojis.status_online else: diff --git a/bot/exts/core/internal_eval/_internal_eval.py b/bot/exts/core/internal_eval/_internal_eval.py index 190a15ec..2daf8ef9 100644 --- a/bot/exts/core/internal_eval/_internal_eval.py +++ b/bot/exts/core/internal_eval/_internal_eval.py @@ -9,7 +9,6 @@ from discord.ext import commands from bot.bot import Bot from bot.constants import Client, Roles from bot.utils.decorators import with_role -from bot.utils.extensions import invoke_help_command from ._helpers import EvalContext @@ -154,7 +153,7 @@ class InternalEval(commands.Cog): async def internal_group(self, ctx: commands.Context) -> None: """Internal commands. Top secret!""" if not ctx.invoked_subcommand: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @internal_group.command(name="eval", aliases=("e",)) @with_role(Roles.admins) diff --git a/bot/exts/events/advent_of_code/_cog.py b/bot/exts/events/advent_of_code/_cog.py index 1d8b0ca7..ab5a7a34 100644 --- a/bot/exts/events/advent_of_code/_cog.py +++ b/bot/exts/events/advent_of_code/_cog.py @@ -18,7 +18,6 @@ from bot.exts.events.advent_of_code.views.dayandstarview import AoCDropdownView from bot.utils import members from bot.utils.decorators import InChannelCheckFailure, in_month, whitelist_override, with_role from bot.utils.exceptions import MovedCommandError -from bot.utils.extensions import invoke_help_command log = logging.getLogger(__name__) @@ -122,7 +121,7 @@ class AdventOfCode(commands.Cog): async def adventofcode_group(self, ctx: commands.Context) -> None: """All of the Advent of Code commands.""" if not ctx.invoked_subcommand: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @with_role(Roles.admins) @adventofcode_group.command( diff --git a/bot/exts/fun/game.py b/bot/exts/fun/game.py index 5f56bef7..4730d5b3 100644 --- a/bot/exts/fun/game.py +++ b/bot/exts/fun/game.py @@ -15,7 +15,6 @@ from discord.ext.commands import Cog, Context, group from bot.bot import Bot from bot.constants import STAFF_ROLES, Tokens from bot.utils.decorators import with_role -from bot.utils.extensions import invoke_help_command from bot.utils.pagination import ImagePaginator, LinePaginator # Base URL of IGDB API @@ -267,7 +266,7 @@ class Games(Cog): """ # When user didn't specified genre, send help message if genre is None: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return # Capitalize genre for check diff --git a/bot/exts/fun/minesweeper.py b/bot/exts/fun/minesweeper.py index a48b5051..782fb9d8 100644 --- a/bot/exts/fun/minesweeper.py +++ b/bot/exts/fun/minesweeper.py @@ -11,7 +11,6 @@ from bot.bot import Bot from bot.constants import Client from bot.utils.converters import CoordinateConverter from bot.utils.exceptions import UserNotPlayingError -from bot.utils.extensions import invoke_help_command MESSAGE_MAPPING = { 0: ":stop_button:", @@ -51,13 +50,14 @@ class Game: class Minesweeper(commands.Cog): """Play a game of Minesweeper.""" - def __init__(self): + def __init__(self, bot: Bot): + self.bot = bot self.games: dict[int, Game] = {} @commands.group(name="minesweeper", aliases=("ms",), invoke_without_command=True) async def minesweeper_group(self, ctx: commands.Context) -> None: """Commands for Playing Minesweeper.""" - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @staticmethod def get_neighbours(x: int, y: int) -> Iterator[tuple[int, int]]: @@ -267,4 +267,4 @@ class Minesweeper(commands.Cog): def setup(bot: Bot) -> None: """Load the Minesweeper cog.""" - bot.add_cog(Minesweeper()) + bot.add_cog(Minesweeper(bot)) diff --git a/bot/exts/fun/movie.py b/bot/exts/fun/movie.py index a04eeb41..4418b938 100644 --- a/bot/exts/fun/movie.py +++ b/bot/exts/fun/movie.py @@ -9,7 +9,6 @@ from discord.ext.commands import Cog, Context, group from bot.bot import Bot from bot.constants import Tokens -from bot.utils.extensions import invoke_help_command from bot.utils.pagination import ImagePaginator # Define base URL of TMDB @@ -50,6 +49,7 @@ class Movie(Cog): """Movie Cog contains movies command that grab random movies from TMDB.""" def __init__(self, bot: Bot): + self.bot = bot self.http_session: ClientSession = bot.http_session @group(name="movies", aliases=("movie",), invoke_without_command=True) @@ -73,7 +73,7 @@ class Movie(Cog): try: result = await self.get_movies_data(self.http_session, MovieGenres[genre].value, 1) except KeyError: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return # Check if "results" is in result. If not, throw error. diff --git a/bot/exts/fun/snakes/_snakes_cog.py b/bot/exts/fun/snakes/_snakes_cog.py index 59e57199..96718200 100644 --- a/bot/exts/fun/snakes/_snakes_cog.py +++ b/bot/exts/fun/snakes/_snakes_cog.py @@ -22,7 +22,6 @@ from bot.constants import ERROR_REPLIES, Tokens from bot.exts.fun.snakes import _utils as utils from bot.exts.fun.snakes._converter import Snake from bot.utils.decorators import locked -from bot.utils.extensions import invoke_help_command log = logging.getLogger(__name__) @@ -440,7 +439,7 @@ class Snakes(Cog): @group(name="snakes", aliases=("snake",), invoke_without_command=True) async def snakes_group(self, ctx: Context) -> None: """Commands from our first code jam.""" - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @bot_has_permissions(manage_messages=True) @snakes_group.command(name="antidote") diff --git a/bot/exts/fun/space.py b/bot/exts/fun/space.py index 48ad0f96..0bbe0b33 100644 --- a/bot/exts/fun/space.py +++ b/bot/exts/fun/space.py @@ -11,7 +11,6 @@ from discord.ext.commands import Cog, Context, group from bot.bot import Bot from bot.constants import Tokens from bot.utils.converters import DateConverter -from bot.utils.extensions import invoke_help_command logger = logging.getLogger(__name__) @@ -27,6 +26,7 @@ class Space(Cog): def __init__(self, bot: Bot): self.http_session = bot.http_session + self.bot = bot self.rovers = {} self.get_rovers.start() @@ -50,7 +50,7 @@ class Space(Cog): @group(name="space", invoke_without_command=True) async def space(self, ctx: Context) -> None: """Head command that contains commands about space.""" - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @space.command(name="apod") async def apod(self, ctx: Context, date: Optional[str]) -> None: diff --git a/bot/exts/utilities/colour.py b/bot/exts/utilities/colour.py index ee6bad93..5282bc6d 100644 --- a/bot/exts/utilities/colour.py +++ b/bot/exts/utilities/colour.py @@ -13,7 +13,6 @@ from discord.ext import commands from bot import constants from bot.bot import Bot -from bot.exts.core.extensions import invoke_help_command from bot.utils.decorators import whitelist_override THUMBNAIL_SIZE = (80, 80) @@ -99,7 +98,7 @@ class Colour(commands.Cog): extra_colour = ImageColor.getrgb(colour_input) await self.send_colour_response(ctx, extra_colour) except ValueError: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @colour.command() async def rgb(self, ctx: commands.Context, red: int, green: int, blue: int) -> None: diff --git a/bot/exts/utilities/emoji.py b/bot/exts/utilities/emoji.py index fa438d7f..2b2fab8a 100644 --- a/bot/exts/utilities/emoji.py +++ b/bot/exts/utilities/emoji.py @@ -10,7 +10,6 @@ from discord.ext import commands from bot.bot import Bot from bot.constants import Colours, ERROR_REPLIES -from bot.utils.extensions import invoke_help_command from bot.utils.pagination import LinePaginator from bot.utils.time import time_since @@ -20,6 +19,9 @@ log = logging.getLogger(__name__) class Emojis(commands.Cog): """A collection of commands related to emojis in the server.""" + def __init__(self, bot: Bot) -> None: + self.bot = bot + @staticmethod def embed_builder(emoji: dict) -> tuple[Embed, list[str]]: """Generates an embed with the emoji names and count.""" @@ -74,7 +76,7 @@ class Emojis(commands.Cog): if emoji is not None: await ctx.invoke(self.info_command, emoji) else: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @emoji_group.command(name="count", aliases=("c",)) async def count_command(self, ctx: commands.Context, *, category_query: str = None) -> None: @@ -120,4 +122,4 @@ class Emojis(commands.Cog): def setup(bot: Bot) -> None: """Load the Emojis cog.""" - bot.add_cog(Emojis()) + bot.add_cog(Emojis(bot)) diff --git a/bot/exts/utilities/epoch.py b/bot/exts/utilities/epoch.py index 42312dd1..2a21688e 100644 --- a/bot/exts/utilities/epoch.py +++ b/bot/exts/utilities/epoch.py @@ -6,7 +6,6 @@ from dateutil import parser from discord.ext import commands from bot.bot import Bot -from bot.utils.extensions import invoke_help_command # https://discord.com/developers/docs/reference#message-formatting-timestamp-styles STYLES = { @@ -48,6 +47,9 @@ class DateString(commands.Converter): class Epoch(commands.Cog): """Convert an entered time and date to a unix timestamp.""" + def __init__(self, bot: Bot) -> None: + self.bot = bot + @commands.command(name="epoch") async def epoch(self, ctx: commands.Context, *, date_time: DateString = None) -> None: """ @@ -71,7 +73,7 @@ class Epoch(commands.Cog): Times in the dropdown are shown in UTC """ if not date_time: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return if isinstance(date_time, tuple): @@ -135,4 +137,4 @@ class TimestampMenuView(discord.ui.View): def setup(bot: Bot) -> None: """Load the Epoch cog.""" - bot.add_cog(Epoch()) + bot.add_cog(Epoch(bot)) diff --git a/bot/exts/utilities/githubinfo.py b/bot/exts/utilities/githubinfo.py index 046f67df..ed176290 100644 --- a/bot/exts/utilities/githubinfo.py +++ b/bot/exts/utilities/githubinfo.py @@ -12,7 +12,6 @@ from discord.ext import commands from bot.bot import Bot from bot.constants import Colours, ERROR_REPLIES, Emojis, NEGATIVE_REPLIES, Tokens -from bot.exts.core.extensions import invoke_help_command log = logging.getLogger(__name__) @@ -168,7 +167,7 @@ class GithubInfo(commands.Cog): async def github_group(self, ctx: commands.Context) -> None: """Commands for finding information related to GitHub.""" if ctx.invoked_subcommand is None: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @commands.Cog.listener() async def on_message(self, message: discord.Message) -> None: diff --git a/bot/exts/utilities/reddit.py b/bot/exts/utilities/reddit.py index 782583d2..fa50eb36 100644 --- a/bot/exts/utilities/reddit.py +++ b/bot/exts/utilities/reddit.py @@ -15,7 +15,6 @@ from discord.utils import escape_markdown, sleep_until from bot.bot import Bot from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES from bot.utils.converters import Subreddit -from bot.utils.extensions import invoke_help_command from bot.utils.messages import sub_clyde from bot.utils.pagination import ImagePaginator, LinePaginator @@ -302,7 +301,7 @@ class Reddit(Cog): @group(name="reddit", invoke_without_command=True) async def reddit_group(self, ctx: Context) -> None: """View the top posts from various subreddits.""" - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) @reddit_group.command(name="top") async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None: diff --git a/bot/exts/utilities/twemoji.py b/bot/exts/utilities/twemoji.py index c915f05b..a4477bc1 100644 --- a/bot/exts/utilities/twemoji.py +++ b/bot/exts/utilities/twemoji.py @@ -9,7 +9,6 @@ from emoji import UNICODE_EMOJI_ENGLISH, is_emoji from bot.bot import Bot from bot.constants import Colours, Roles from bot.utils.decorators import whitelist_override -from bot.utils.extensions import invoke_help_command log = logging.getLogger(__name__) BASE_URLS = { @@ -133,7 +132,7 @@ class Twemoji(commands.Cog): async def twemoji(self, ctx: commands.Context, *raw_emoji: str) -> None: """Sends a preview of a given Twemoji, specified by codepoint or emoji.""" if len(raw_emoji) == 0: - await invoke_help_command(ctx) + await self.bot.invoke_help_command(ctx) return try: codepoint = self.codepoint_from_input(raw_emoji) diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py deleted file mode 100644 index 09192ae2..00000000 --- a/bot/utils/extensions.py +++ /dev/null @@ -1,45 +0,0 @@ -import importlib -import inspect -import pkgutil -from collections.abc import Iterator -from typing import NoReturn - -from discord.ext.commands import Context - -from bot import exts - - -def unqualify(name: str) -> str: - """Return an unqualified name given a qualified module/package `name`.""" - return name.rsplit(".", maxsplit=1)[-1] - - -def walk_extensions() -> Iterator[str]: - """Yield extension names from the bot.exts subpackage.""" - - def on_error(name: str) -> NoReturn: - raise ImportError(name=name) # pragma: no cover - - for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error): - if unqualify(module.name).startswith("_"): - # Ignore module/package names starting with an underscore. - continue - - if module.ispkg: - imported = importlib.import_module(module.name) - if not inspect.isfunction(getattr(imported, "setup", None)): - # If it lacks a setup function, it's not an extension. - continue - - yield module.name - - -async def invoke_help_command(ctx: Context) -> None: - """Invoke the help command or default help command if help extensions is not loaded.""" - if "bot.exts.core.help" in ctx.bot.extensions: - help_command = ctx.bot.get_command("help") - await ctx.invoke(help_command, ctx.command.qualified_name) - return - await ctx.send_help(ctx.command) - -EXTENSIONS = frozenset(walk_extensions()) |