diff options
author | 2019-10-07 10:23:13 -0700 | |
---|---|---|
committer | 2019-10-07 10:23:13 -0700 | |
commit | 77216353a87bcf2dbf67cfe028f9f38ba7a2406e (patch) | |
tree | aa7f131c7418397462a594d4b51fc07197751767 | |
parent | Use quotes instead of back ticks around asterisk in docstrings (diff) |
Support wildcards and multiple extensions for load and unload commands
* Rename batch_reload() to batch_manage() and make it accept an
action as a parameter so that it can be a generic function.
* Switch parameter order for manage() to make it consistent with
batch_manage().
* Always call batch_manage() and make it defer to manage() when only 1
extension is given.
* Make batch_manage() a regular method instead of a coroutine.
-rw-r--r-- | bot/cogs/extensions.py | 84 |
1 files changed, 48 insertions, 36 deletions
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index a385e50d5..5f9b4aef4 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -41,7 +41,7 @@ class Extension(commands.Converter): async def convert(self, ctx: Context, argument: str) -> str: """Fully qualify the name of an extension and ensure it exists.""" # Special values to reload all extensions - if ctx.command.name == "reload" and (argument == "*" or argument == "**"): + if argument == "*" or argument == "**": return argument argument = argument.lower() @@ -67,18 +67,34 @@ class Extensions(commands.Cog): await ctx.invoke(self.bot.get_command("help"), "extensions") @extensions_group.command(name="load", aliases=("l",)) - async def load_command(self, ctx: Context, extension: Extension) -> None: - """Load an extension given its fully qualified or unqualified name.""" - msg, _ = self.manage(extension, Action.LOAD) + async def load_command(self, ctx: Context, *extensions: Extension) -> None: + """ + Load extensions given their fully qualified or unqualified names. + + If '*' or '**' is given as the name, all unloaded extensions will be loaded. + """ + if "*" in extensions or "**" in extensions: + extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + + msg = self.batch_manage(Action.LOAD, *extensions) await ctx.send(msg) @extensions_group.command(name="unload", aliases=("ul",)) - async def unload_command(self, ctx: Context, extension: Extension) -> None: - """Unload a currently loaded extension given its fully qualified or unqualified name.""" - if extension in UNLOAD_BLACKLIST: - msg = f":x: The extension `{extension}` may not be unloaded." + async def unload_command(self, ctx: Context, *extensions: Extension) -> None: + """ + Unload currently loaded extensions given their fully qualified or unqualified names. + + If '*' or '**' is given as the name, all loaded extensions will be unloaded. + """ + blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + + if blacklisted: + msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" else: - msg, _ = self.manage(extension, Action.UNLOAD) + if "*" in extensions or "**" in extensions: + extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + + msg = self.batch_manage(Action.UNLOAD, *extensions) await ctx.send(msg) @@ -96,10 +112,13 @@ class Extensions(commands.Cog): await ctx.invoke(self.bot.get_command("help"), "extensions reload") return - if len(extensions) > 1: - msg = await self.batch_reload(*extensions) - else: - msg, _ = self.manage(extensions[0], Action.RELOAD) + if "**" in extensions: + extensions = EXTENSIONS + elif "*" in extensions: + extensions = set(self.bot.extensions.keys()) | set(extensions) + extensions.remove("*") + + msg = self.batch_manage(Action.RELOAD, *extensions) await ctx.send(msg) @@ -133,43 +152,36 @@ class Extensions(commands.Cog): log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) - async def batch_reload(self, *extensions: str) -> str: + def batch_manage(self, action: Action, *extensions: str) -> str: """ - Reload given extensions and return a message with the results. + Apply an action to multiple extensions and return a message with the results. - If '*' is given, all currently loaded extensions will be reloaded along with any other - specified extensions. If '**' is given, all extensions, including unloaded ones, will be - reloaded. + If only one extension is given, it is deferred to `manage()`. """ - failures = {} + if len(extensions) == 1: + msg, _ = self.manage(action, extensions[0]) + return msg - if "**" in extensions: - to_reload = EXTENSIONS - elif "*" in extensions: - to_reload = set(self.bot.extensions.keys()) | set(extensions) - to_reload.remove("*") - elif extensions: - to_reload = extensions - else: - to_reload = self.bot.extensions.copy().keys() + verb = action.name.lower() + failures = {} - for extension in to_reload: - _, error = self.manage(extension, Action.RELOAD) + for extension in extensions: + _, error = self.manage(action, extension) if error: failures[extension] = error emoji = ":x:" if failures else ":ok_hand:" - msg = f"{emoji} {len(to_reload) - len(failures)} / {len(to_reload)} extensions reloaded." + msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." if failures: - failures = '\n'.join(f'{ext}\n {err}' for ext, err in failures.items()) - msg += f'\nFailures:```{failures}```' + failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items()) + msg += f"\nFailures:```{failures}```" - log.debug(f'Reloaded all extensions.') + log.debug(f"Batch {verb}ed extensions.") return msg - def manage(self, ext: str, action: Action) -> t.Tuple[str, t.Optional[str]]: + def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: """Apply an action to an extension and return the status message and any error message.""" verb = action.name.lower() error_msg = None @@ -179,7 +191,7 @@ class Extensions(commands.Cog): except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): if action is Action.RELOAD: # When reloading, just load the extension if it was not loaded. - return self.manage(ext, Action.LOAD) + return self.manage(Action.LOAD, ext) msg = f":x: Extension `{ext}` is already {verb}ed." log.debug(msg[4:]) |