aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/extensions.py84
1 files changed, 48 insertions, 36 deletions
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
index a385e50d5..5f9b4aef4 100644
--- a/bot/cogs/extensions.py
+++ b/bot/cogs/extensions.py
@@ -41,7 +41,7 @@ class Extension(commands.Converter):
async def convert(self, ctx: Context, argument: str) -> str:
"""Fully qualify the name of an extension and ensure it exists."""
# Special values to reload all extensions
- if ctx.command.name == "reload" and (argument == "*" or argument == "**"):
+ if argument == "*" or argument == "**":
return argument
argument = argument.lower()
@@ -67,18 +67,34 @@ class Extensions(commands.Cog):
await ctx.invoke(self.bot.get_command("help"), "extensions")
@extensions_group.command(name="load", aliases=("l",))
- async def load_command(self, ctx: Context, extension: Extension) -> None:
- """Load an extension given its fully qualified or unqualified name."""
- msg, _ = self.manage(extension, Action.LOAD)
+ async def load_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Load extensions given their fully qualified or unqualified names.
+
+ If '*' or '**' is given as the name, all unloaded extensions will be loaded.
+ """
+ if "*" in extensions or "**" in extensions:
+ extensions = set(EXTENSIONS) - set(self.bot.extensions.keys())
+
+ msg = self.batch_manage(Action.LOAD, *extensions)
await ctx.send(msg)
@extensions_group.command(name="unload", aliases=("ul",))
- async def unload_command(self, ctx: Context, extension: Extension) -> None:
- """Unload a currently loaded extension given its fully qualified or unqualified name."""
- if extension in UNLOAD_BLACKLIST:
- msg = f":x: The extension `{extension}` may not be unloaded."
+ async def unload_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Unload currently loaded extensions given their fully qualified or unqualified names.
+
+ If '*' or '**' is given as the name, all loaded extensions will be unloaded.
+ """
+ blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions))
+
+ if blacklisted:
+ msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```"
else:
- msg, _ = self.manage(extension, Action.UNLOAD)
+ if "*" in extensions or "**" in extensions:
+ extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST
+
+ msg = self.batch_manage(Action.UNLOAD, *extensions)
await ctx.send(msg)
@@ -96,10 +112,13 @@ class Extensions(commands.Cog):
await ctx.invoke(self.bot.get_command("help"), "extensions reload")
return
- if len(extensions) > 1:
- msg = await self.batch_reload(*extensions)
- else:
- msg, _ = self.manage(extensions[0], Action.RELOAD)
+ if "**" in extensions:
+ extensions = EXTENSIONS
+ elif "*" in extensions:
+ extensions = set(self.bot.extensions.keys()) | set(extensions)
+ extensions.remove("*")
+
+ msg = self.batch_manage(Action.RELOAD, *extensions)
await ctx.send(msg)
@@ -133,43 +152,36 @@ class Extensions(commands.Cog):
log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False)
- async def batch_reload(self, *extensions: str) -> str:
+ def batch_manage(self, action: Action, *extensions: str) -> str:
"""
- Reload given extensions and return a message with the results.
+ Apply an action to multiple extensions and return a message with the results.
- If '*' is given, all currently loaded extensions will be reloaded along with any other
- specified extensions. If '**' is given, all extensions, including unloaded ones, will be
- reloaded.
+ If only one extension is given, it is deferred to `manage()`.
"""
- failures = {}
+ if len(extensions) == 1:
+ msg, _ = self.manage(action, extensions[0])
+ return msg
- if "**" in extensions:
- to_reload = EXTENSIONS
- elif "*" in extensions:
- to_reload = set(self.bot.extensions.keys()) | set(extensions)
- to_reload.remove("*")
- elif extensions:
- to_reload = extensions
- else:
- to_reload = self.bot.extensions.copy().keys()
+ verb = action.name.lower()
+ failures = {}
- for extension in to_reload:
- _, error = self.manage(extension, Action.RELOAD)
+ for extension in extensions:
+ _, error = self.manage(action, extension)
if error:
failures[extension] = error
emoji = ":x:" if failures else ":ok_hand:"
- msg = f"{emoji} {len(to_reload) - len(failures)} / {len(to_reload)} extensions reloaded."
+ msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed."
if failures:
- failures = '\n'.join(f'{ext}\n {err}' for ext, err in failures.items())
- msg += f'\nFailures:```{failures}```'
+ failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items())
+ msg += f"\nFailures:```{failures}```"
- log.debug(f'Reloaded all extensions.')
+ log.debug(f"Batch {verb}ed extensions.")
return msg
- def manage(self, ext: str, action: Action) -> t.Tuple[str, t.Optional[str]]:
+ def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:
"""Apply an action to an extension and return the status message and any error message."""
verb = action.name.lower()
error_msg = None
@@ -179,7 +191,7 @@ class Extensions(commands.Cog):
except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):
if action is Action.RELOAD:
# When reloading, just load the extension if it was not loaded.
- return self.manage(ext, Action.LOAD)
+ return self.manage(Action.LOAD, ext)
msg = f":x: Extension `{ext}` is already {verb}ed."
log.debug(msg[4:])