diff options
| author | 2019-10-03 19:06:27 -0700 | |
|---|---|---|
| committer | 2019-10-03 19:07:17 -0700 | |
| commit | a01a969512b8eb11a337b9c5292bae1d678429a2 (patch) | |
| tree | aea8aab4380094fa148a835366b323b8755b8170 | |
| parent | Add a generic method to manage loading/unloading extensions (diff) | |
Add a custom converter for extensions
The converter fully qualifies the extension's name and ensures the
extension exists.
* Make the extensions set a module constant instead of an instant
  attribute and make it a frozenset.
* Add a cog error handler to handle BadArgument locally and prevent the
  help command from showing for such errors.
| -rw-r--r-- | bot/cogs/extensions.py | 41 | 
1 files changed, 33 insertions, 8 deletions
| diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index 83048bb76..e50ef5553 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -5,7 +5,7 @@ from enum import Enum  from pkgutil import iter_modules  from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group  from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs  from bot.pagination import LinePaginator @@ -14,6 +14,7 @@ from bot.utils.checks import with_role_check  log = logging.getLogger(__name__)  KEEP_LOADED = ["bot.cogs.extensions", "bot.cogs.modlog"] +EXTENSIONS = frozenset(ext for ext in iter_modules(("bot/cogs", "bot.cogs")) if ext.name[-1] != "_")  class Action(Enum): @@ -24,16 +25,36 @@ class Action(Enum):      RELOAD = (Bot.unload_extension, Bot.load_extension) +class Extension(Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if ctx.command.name == "reload" and (argument == "*" or argument == "**"): +            return argument + +        argument = argument.lower() + +        if "." not in argument: +            argument = f"bot.cogs.{argument}" + +        if argument in EXTENSIONS: +            return argument +        else: +            raise BadArgument(f":x: Could not find the extension `{argument}`.") + +  class Extensions(Cog):      """Extension management commands."""      def __init__(self, bot: Bot):          self.bot = bot -        log.info("Initialising extension names...") -        modules = iter_modules(("bot/cogs", "bot.cogs")) -        self.cogs = set(ext for ext in modules if ext.name[-1] != "_") -      @group(name='extensions', aliases=('c', 'ext', 'exts'), invoke_without_command=True)      async def extensions_group(self, ctx: Context) -> None:          """Load, unload, reload, and list active cogs.""" @@ -291,9 +312,6 @@ class Extensions(Cog):          verb = action.name.lower()          error_msg = None -        if ext not in self.cogs: -            return f":x: Extension {ext} does not exist.", None -          if (              (action is Action.LOAD and ext not in self.bot.extensions)              or (action is Action.UNLOAD and ext in self.bot.extensions) @@ -321,6 +339,13 @@ class Extensions(Cog):          """Only allow moderators and core developers to invoke the commands in this cog."""          return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer) +    # This cannot be static (must have a __func__ attribute). +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Handle BadArgument errors locally to prevent the help command from showing.""" +        if isinstance(error, BadArgument): +            await ctx.send(str(error)) +            error.handled = True +  def setup(bot: Bot) -> None:      """Load the Extensions cog.""" | 
