aboutsummaryrefslogtreecommitdiffstats
path: root/bot/exts
diff options
context:
space:
mode:
Diffstat (limited to 'bot/exts')
-rw-r--r--bot/exts/__init__.py23
-rw-r--r--bot/exts/evergreen/snakes/__init__.py2
-rw-r--r--bot/exts/evergreen/snakes/_converter.py (renamed from bot/exts/evergreen/snakes/converter.py)2
-rw-r--r--bot/exts/evergreen/snakes/_snakes_cog.py (renamed from bot/exts/evergreen/snakes/snakes_cog.py)4
-rw-r--r--bot/exts/evergreen/snakes/_utils.py (renamed from bot/exts/evergreen/snakes/utils.py)0
-rw-r--r--bot/exts/utils/__init__.py0
-rw-r--r--bot/exts/utils/extensions.py265
7 files changed, 270 insertions, 26 deletions
diff --git a/bot/exts/__init__.py b/bot/exts/__init__.py
index 25deb9af..13f484ac 100644
--- a/bot/exts/__init__.py
+++ b/bot/exts/__init__.py
@@ -1,9 +1,8 @@
import logging
import pkgutil
-from pathlib import Path
from typing import Iterator
-__all__ = ("get_package_names", "walk_extensions")
+__all__ = ("get_package_names",)
log = logging.getLogger(__name__)
@@ -13,23 +12,3 @@ def get_package_names() -> Iterator[str]:
for package in pkgutil.iter_modules(__path__):
if package.ispkg:
yield package.name
-
-
-def walk_extensions() -> Iterator[str]:
- """
- Iterate dot-separated paths to all extensions.
-
- The strings are formatted in a way such that the bot's `load_extension`
- method can take them. Use this to load all available extensions.
-
- This intentionally doesn't make use of pkgutil's `walk_packages`, as we only
- want to build paths to extensions - not recursively all modules. For some
- extensions, the `setup` function is in the package's __init__ file, while
- modules nested under the package are only helpers. Constructing the paths
- ourselves serves our purpose better.
- """
- base_path = Path(__path__[0])
-
- for package in get_package_names():
- for extension in pkgutil.iter_modules([base_path.joinpath(package)]):
- yield f"bot.exts.{package}.{extension.name}"
diff --git a/bot/exts/evergreen/snakes/__init__.py b/bot/exts/evergreen/snakes/__init__.py
index 2eae2751..bc42f0c2 100644
--- a/bot/exts/evergreen/snakes/__init__.py
+++ b/bot/exts/evergreen/snakes/__init__.py
@@ -2,7 +2,7 @@ import logging
from discord.ext import commands
-from bot.exts.evergreen.snakes.snakes_cog import Snakes
+from bot.exts.evergreen.snakes._snakes_cog import Snakes
log = logging.getLogger(__name__)
diff --git a/bot/exts/evergreen/snakes/converter.py b/bot/exts/evergreen/snakes/_converter.py
index 55609b8e..eee248cf 100644
--- a/bot/exts/evergreen/snakes/converter.py
+++ b/bot/exts/evergreen/snakes/_converter.py
@@ -7,7 +7,7 @@ import discord
from discord.ext.commands import Context, Converter
from fuzzywuzzy import fuzz
-from bot.exts.evergreen.snakes.utils import SNAKE_RESOURCES
+from bot.exts.evergreen.snakes._utils import SNAKE_RESOURCES
from bot.utils import disambiguate
log = logging.getLogger(__name__)
diff --git a/bot/exts/evergreen/snakes/snakes_cog.py b/bot/exts/evergreen/snakes/_snakes_cog.py
index 9bbad9fe..a846274b 100644
--- a/bot/exts/evergreen/snakes/snakes_cog.py
+++ b/bot/exts/evergreen/snakes/_snakes_cog.py
@@ -18,8 +18,8 @@ from discord import Colour, Embed, File, Member, Message, Reaction
from discord.ext.commands import BadArgument, Bot, Cog, CommandError, Context, bot_has_permissions, group
from bot.constants import ERROR_REPLIES, Tokens
-from bot.exts.evergreen.snakes import utils
-from bot.exts.evergreen.snakes.converter import Snake
+from bot.exts.evergreen.snakes import _utils as utils
+from bot.exts.evergreen.snakes._converter import Snake
from bot.utils.decorators import locked
log = logging.getLogger(__name__)
diff --git a/bot/exts/evergreen/snakes/utils.py b/bot/exts/evergreen/snakes/_utils.py
index 7d6caf04..7d6caf04 100644
--- a/bot/exts/evergreen/snakes/utils.py
+++ b/bot/exts/evergreen/snakes/_utils.py
diff --git a/bot/exts/utils/__init__.py b/bot/exts/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/bot/exts/utils/__init__.py
diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py
new file mode 100644
index 00000000..102a0416
--- /dev/null
+++ b/bot/exts/utils/extensions.py
@@ -0,0 +1,265 @@
+import functools
+import logging
+import typing as t
+from enum import Enum
+
+from discord import Colour, Embed
+from discord.ext import commands
+from discord.ext.commands import Context, group
+
+from bot import exts
+from bot.bot import SeasonalBot as Bot
+from bot.constants import Client, Emojis, MODERATION_ROLES, Roles
+from bot.utils.checks import with_role_check
+from bot.utils.extensions import EXTENSIONS, unqualify
+from bot.utils.pagination import LinePaginator
+
+log = logging.getLogger(__name__)
+
+
+UNLOAD_BLACKLIST = {f"{exts.__name__}.utils.extensions"}
+BASE_PATH_LEN = len(exts.__name__.split("."))
+
+
+class Action(Enum):
+ """Represents an action to perform on an extension."""
+
+ # Need to be partial otherwise they are considered to be function definitions.
+ LOAD = functools.partial(Bot.load_extension)
+ UNLOAD = functools.partial(Bot.unload_extension)
+ RELOAD = functools.partial(Bot.reload_extension)
+
+
+class Extension(commands.Converter):
+ """
+ Fully qualify the name of an extension and ensure it exists.
+
+ The * and ** values bypass this when used with the reload command.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Fully qualify the name of an extension and ensure it exists."""
+ # Special values to reload all extensions
+ if argument == "*" or argument == "**":
+ return argument
+
+ argument = argument.lower()
+
+ if argument in EXTENSIONS:
+ return argument
+ elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS:
+ return qualified_arg
+
+ matches = []
+ for ext in EXTENSIONS:
+ if argument == unqualify(ext):
+ matches.append(ext)
+
+ if len(matches) > 1:
+ matches.sort()
+ names = "\n".join(matches)
+ raise commands.BadArgument(
+ f":x: `{argument}` is an ambiguous extension name. "
+ f"Please use one of the following fully-qualified names.```\n{names}```"
+ )
+ elif matches:
+ return matches[0]
+ else:
+ raise commands.BadArgument(f":x: Could not find the extension `{argument}`.")
+
+
+class Extensions(commands.Cog):
+ """Extension management commands."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+
+ @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True)
+ async def extensions_group(self, ctx: Context) -> None:
+ """Load, unload, reload, and list loaded extensions."""
+ await ctx.send_help(ctx.command)
+
+ @extensions_group.command(name="load", aliases=("l",))
+ async def load_command(self, ctx: Context, *extensions: Extension) -> None:
+ r"""
+ Load extensions given their fully qualified or unqualified names.
+
+ If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.send_help(ctx.command)
+ return
+
+ if "*" in extensions or "**" in extensions:
+ extensions = set(EXTENSIONS) - set(self.bot.extensions.keys())
+
+ msg = self.batch_manage(Action.LOAD, *extensions)
+ await ctx.send(msg)
+
+ @extensions_group.command(name="unload", aliases=("ul",))
+ async def unload_command(self, ctx: Context, *extensions: Extension) -> None:
+ r"""
+ Unload currently loaded extensions given their fully qualified or unqualified names.
+
+ If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.send_help(ctx.command)
+ return
+
+ blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions))
+
+ if blacklisted:
+ msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```"
+ else:
+ if "*" in extensions or "**" in extensions:
+ extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST
+
+ msg = self.batch_manage(Action.UNLOAD, *extensions)
+
+ await ctx.send(msg)
+
+ @extensions_group.command(name="reload", aliases=("r",), root_aliases=("reload",))
+ async def reload_command(self, ctx: Context, *extensions: Extension) -> None:
+ r"""
+ Reload extensions given their fully qualified or unqualified names.
+
+ If an extension fails to be reloaded, it will be rolled-back to the prior working state.
+
+ If '\*' is given as the name, all currently loaded extensions will be reloaded.
+ If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.send_help(ctx.command)
+ return
+
+ if "**" in extensions:
+ extensions = EXTENSIONS
+ elif "*" in extensions:
+ extensions = set(self.bot.extensions.keys()) | set(extensions)
+ extensions.remove("*")
+
+ msg = self.batch_manage(Action.RELOAD, *extensions)
+
+ await ctx.send(msg)
+
+ @extensions_group.command(name="list", aliases=("all",))
+ async def list_command(self, ctx: Context) -> None:
+ """
+ Get a list of all extensions, including their loaded status.
+
+ Grey indicates that the extension is unloaded.
+ Green indicates that the extension is currently loaded.
+ """
+ embed = Embed(colour=Colour.blurple())
+ embed.set_author(
+ name="Extensions List",
+ url=Client.github_bot_repo,
+ icon_url=str(self.bot.user.avatar_url)
+ )
+
+ lines = []
+ categories = self.group_extension_statuses()
+ for category, extensions in sorted(categories.items()):
+ # Treat each category as a single line by concatenating everything.
+ # This ensures the paginator will not cut off a page in the middle of a category.
+ category = category.replace("_", " ").title()
+ extensions = "\n".join(sorted(extensions))
+ lines.append(f"**{category}**\n{extensions}\n")
+
+ log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
+ await LinePaginator.paginate(lines, ctx, embed, max_size=1200, empty=False)
+
+ def group_extension_statuses(self) -> t.Mapping[str, str]:
+ """Return a mapping of extension names and statuses to their categories."""
+ categories = {}
+
+ for ext in EXTENSIONS:
+ if ext in self.bot.extensions:
+ status = Emojis.status_online
+ else:
+ status = Emojis.status_offline
+
+ path = ext.split(".")
+ if len(path) > BASE_PATH_LEN + 1:
+ category = " - ".join(path[BASE_PATH_LEN:-1])
+ else:
+ category = "uncategorised"
+
+ categories.setdefault(category, []).append(f"{status} {path[-1]}")
+
+ return categories
+
+ def batch_manage(self, action: Action, *extensions: str) -> str:
+ """
+ Apply an action to multiple extensions and return a message with the results.
+
+ If only one extension is given, it is deferred to `manage()`.
+ """
+ if len(extensions) == 1:
+ msg, _ = self.manage(action, extensions[0])
+ return msg
+
+ verb = action.name.lower()
+ failures = {}
+
+ for extension in extensions:
+ _, error = self.manage(action, extension)
+ if error:
+ failures[extension] = error
+
+ emoji = ":x:" if failures else ":ok_hand:"
+ msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed."
+
+ if failures:
+ failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items())
+ msg += f"\nFailures:```{failures}```"
+
+ log.debug(f"Batch {verb}ed extensions.")
+
+ return msg
+
+ def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:
+ """Apply an action to an extension and return the status message and any error message."""
+ verb = action.name.lower()
+ error_msg = None
+
+ try:
+ action.value(self.bot, ext)
+ except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):
+ if action is Action.RELOAD:
+ # When reloading, just load the extension if it was not loaded.
+ return self.manage(Action.LOAD, ext)
+
+ msg = f":x: Extension `{ext}` is already {verb}ed."
+ log.debug(msg[4:])
+ except Exception as e:
+ if hasattr(e, "original"):
+ e = e.original
+
+ log.exception(f"Extension '{ext}' failed to {verb}.")
+
+ error_msg = f"{e.__class__.__name__}: {e}"
+ msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```"
+ else:
+ msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`."
+ log.debug(msg[10:])
+
+ return msg, error_msg
+
+ # This cannot be static (must have a __func__ attribute).
+ def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators and core developers to invoke the commands in this cog."""
+ return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers)
+
+ # This cannot be static (must have a __func__ attribute).
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Handle BadArgument errors locally to prevent the help command from showing."""
+ if isinstance(error, commands.BadArgument):
+ await ctx.send(str(error))
+ error.handled = True
+
+
+def setup(bot: Bot) -> None:
+ """Load the Extensions cog."""
+ bot.add_cog(Extensions(bot))