diff options
author | 2020-08-12 23:22:18 -0700 | |
---|---|---|
committer | 2020-08-14 09:43:39 -0700 | |
commit | 34064231d4c2afffc6b6b40d1e2f59f15d897c04 (patch) | |
tree | 6b4777fcc2385a8408f01f3e237c2ffddd3c2239 | |
parent | Fix paths used to load extensions (diff) |
Extensions: adjust discovery to work with dir structure
Discover extensions recursively and ignore any modules/packages whose
names start with an underscore.
-rw-r--r-- | bot/cogs/utils/extensions.py | 36 |
1 files changed, 28 insertions, 8 deletions
diff --git a/bot/cogs/utils/extensions.py b/bot/cogs/utils/extensions.py index 365f198ff..d01825fdd 100644 --- a/bot/cogs/utils/extensions.py +++ b/bot/cogs/utils/extensions.py @@ -1,13 +1,16 @@ import functools +import importlib +import inspect import logging +import pkgutil import typing as t from enum import Enum -from pkgutil import iter_modules from discord import Colour, Embed from discord.ext import commands from discord.ext.commands import Context, group +from bot import cogs from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs from bot.pagination import LinePaginator @@ -15,12 +18,29 @@ from bot.utils.checks import with_role_check log = logging.getLogger(__name__) -UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} -EXTENSIONS = frozenset( - ext.name - for ext in iter_modules(("bot/cogs",), "bot.cogs.") - if ext.name[-1] != "_" -) + +def walk_extensions() -> t.Iterator[str]: + """Yield extension names from the bot.cogs subpackage.""" + + def on_error(name: str) -> t.NoReturn: + raise ImportError(name=name) # pragma: no cover + + for module in pkgutil.walk_packages(cogs.__path__, f"{cogs.__name__}.", onerror=on_error): + if module.name.rsplit(".", maxsplit=1)[-1].startswith("_"): + # Ignore module/package names starting with an underscore. + continue + + if module.ispkg: + imported = importlib.import_module(module.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. + continue + + yield module.name + + +UNLOAD_BLACKLIST = {f"{cogs.__name__}.utils.extensions", f"{cogs.__name__}.moderation.modlog"} +EXTENSIONS = frozenset(walk_extensions()) class Action(Enum): @@ -48,7 +68,7 @@ class Extension(commands.Converter): argument = argument.lower() if "." not in argument: - argument = f"bot.cogs.{argument}" + argument = f"{cogs.__name__}.{argument}" if argument in EXTENSIONS: return argument |