diff options
| author | 2019-10-07 10:23:13 -0700 | |
|---|---|---|
| committer | 2019-10-07 10:23:13 -0700 | |
| commit | 77216353a87bcf2dbf67cfe028f9f38ba7a2406e (patch) | |
| tree | aa7f131c7418397462a594d4b51fc07197751767 | |
| parent | Use quotes instead of back ticks around asterisk in docstrings (diff) | |
Support wildcards and multiple extensions for load and unload commands
* Rename batch_reload() to batch_manage() and make it accept an
  action as a parameter so that it can be a generic function.
* Switch parameter order for manage() to make it consistent with
  batch_manage().
* Always call batch_manage() and make it defer to manage() when only 1
  extension is given.
* Make batch_manage() a regular method instead of a coroutine.
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/extensions.py | 84 | 
1 files changed, 48 insertions, 36 deletions
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index a385e50d5..5f9b4aef4 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -41,7 +41,7 @@ class Extension(commands.Converter):      async def convert(self, ctx: Context, argument: str) -> str:          """Fully qualify the name of an extension and ensure it exists."""          # Special values to reload all extensions -        if ctx.command.name == "reload" and (argument == "*" or argument == "**"): +        if argument == "*" or argument == "**":              return argument          argument = argument.lower() @@ -67,18 +67,34 @@ class Extensions(commands.Cog):          await ctx.invoke(self.bot.get_command("help"), "extensions")      @extensions_group.command(name="load", aliases=("l",)) -    async def load_command(self, ctx: Context, extension: Extension) -> None: -        """Load an extension given its fully qualified or unqualified name.""" -        msg, _ = self.manage(extension, Action.LOAD) +    async def load_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Load extensions given their fully qualified or unqualified names. + +        If '*' or '**' is given as the name, all unloaded extensions will be loaded. +        """ +        if "*" in extensions or "**" in extensions: +            extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + +        msg = self.batch_manage(Action.LOAD, *extensions)          await ctx.send(msg)      @extensions_group.command(name="unload", aliases=("ul",)) -    async def unload_command(self, ctx: Context, extension: Extension) -> None: -        """Unload a currently loaded extension given its fully qualified or unqualified name.""" -        if extension in UNLOAD_BLACKLIST: -            msg = f":x: The extension `{extension}` may not be unloaded." +    async def unload_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Unload currently loaded extensions given their fully qualified or unqualified names. + +        If '*' or '**' is given as the name, all loaded extensions will be unloaded. +        """ +        blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + +        if blacklisted: +            msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```"          else: -            msg, _ = self.manage(extension, Action.UNLOAD) +            if "*" in extensions or "**" in extensions: +                extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + +            msg = self.batch_manage(Action.UNLOAD, *extensions)          await ctx.send(msg) @@ -96,10 +112,13 @@ class Extensions(commands.Cog):              await ctx.invoke(self.bot.get_command("help"), "extensions reload")              return -        if len(extensions) > 1: -            msg = await self.batch_reload(*extensions) -        else: -            msg, _ = self.manage(extensions[0], Action.RELOAD) +        if "**" in extensions: +            extensions = EXTENSIONS +        elif "*" in extensions: +            extensions = set(self.bot.extensions.keys()) | set(extensions) +            extensions.remove("*") + +        msg = self.batch_manage(Action.RELOAD, *extensions)          await ctx.send(msg) @@ -133,43 +152,36 @@ class Extensions(commands.Cog):          log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")          await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) -    async def batch_reload(self, *extensions: str) -> str: +    def batch_manage(self, action: Action, *extensions: str) -> str:          """ -        Reload given extensions and return a message with the results. +        Apply an action to multiple extensions and return a message with the results. -        If '*' is given, all currently loaded extensions will be reloaded along with any other -        specified extensions. If '**' is given, all extensions, including unloaded ones, will be -        reloaded. +        If only one extension is given, it is deferred to `manage()`.          """ -        failures = {} +        if len(extensions) == 1: +            msg, _ = self.manage(action, extensions[0]) +            return msg -        if "**" in extensions: -            to_reload = EXTENSIONS -        elif "*" in extensions: -            to_reload = set(self.bot.extensions.keys()) | set(extensions) -            to_reload.remove("*") -        elif extensions: -            to_reload = extensions -        else: -            to_reload = self.bot.extensions.copy().keys() +        verb = action.name.lower() +        failures = {} -        for extension in to_reload: -            _, error = self.manage(extension, Action.RELOAD) +        for extension in extensions: +            _, error = self.manage(action, extension)              if error:                  failures[extension] = error          emoji = ":x:" if failures else ":ok_hand:" -        msg = f"{emoji} {len(to_reload) - len(failures)} / {len(to_reload)} extensions reloaded." +        msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed."          if failures: -            failures = '\n'.join(f'{ext}\n    {err}' for ext, err in failures.items()) -            msg += f'\nFailures:```{failures}```' +            failures = "\n".join(f"{ext}\n    {err}" for ext, err in failures.items()) +            msg += f"\nFailures:```{failures}```" -        log.debug(f'Reloaded all extensions.') +        log.debug(f"Batch {verb}ed extensions.")          return msg -    def manage(self, ext: str, action: Action) -> t.Tuple[str, t.Optional[str]]: +    def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:          """Apply an action to an extension and return the status message and any error message."""          verb = action.name.lower()          error_msg = None @@ -179,7 +191,7 @@ class Extensions(commands.Cog):          except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):              if action is Action.RELOAD:                  # When reloading, just load the extension if it was not loaded. -                return self.manage(ext, Action.LOAD) +                return self.manage(Action.LOAD, ext)              msg = f":x: Extension `{ext}` is already {verb}ed."              log.debug(msg[4:])  |