aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2019-10-03 19:06:27 -0700
committerGravatar MarkKoz <[email protected]>2019-10-03 19:07:17 -0700
commita01a969512b8eb11a337b9c5292bae1d678429a2 (patch)
treeaea8aab4380094fa148a835366b323b8755b8170
parentAdd a generic method to manage loading/unloading extensions (diff)
Add a custom converter for extensions
The converter fully qualifies the extension's name and ensures the extension exists. * Make the extensions set a module constant instead of an instant attribute and make it a frozenset. * Add a cog error handler to handle BadArgument locally and prevent the help command from showing for such errors.
-rw-r--r--bot/cogs/extensions.py41
1 files changed, 33 insertions, 8 deletions
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
index 83048bb76..e50ef5553 100644
--- a/bot/cogs/extensions.py
+++ b/bot/cogs/extensions.py
@@ -5,7 +5,7 @@ from enum import Enum
from pkgutil import iter_modules
from discord import Colour, Embed
-from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group
from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
from bot.pagination import LinePaginator
@@ -14,6 +14,7 @@ from bot.utils.checks import with_role_check
log = logging.getLogger(__name__)
KEEP_LOADED = ["bot.cogs.extensions", "bot.cogs.modlog"]
+EXTENSIONS = frozenset(ext for ext in iter_modules(("bot/cogs", "bot.cogs")) if ext.name[-1] != "_")
class Action(Enum):
@@ -24,16 +25,36 @@ class Action(Enum):
RELOAD = (Bot.unload_extension, Bot.load_extension)
+class Extension(Converter):
+ """
+ Fully qualify the name of an extension and ensure it exists.
+
+ The * and ** values bypass this when used with the reload command.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Fully qualify the name of an extension and ensure it exists."""
+ # Special values to reload all extensions
+ if ctx.command.name == "reload" and (argument == "*" or argument == "**"):
+ return argument
+
+ argument = argument.lower()
+
+ if "." not in argument:
+ argument = f"bot.cogs.{argument}"
+
+ if argument in EXTENSIONS:
+ return argument
+ else:
+ raise BadArgument(f":x: Could not find the extension `{argument}`.")
+
+
class Extensions(Cog):
"""Extension management commands."""
def __init__(self, bot: Bot):
self.bot = bot
- log.info("Initialising extension names...")
- modules = iter_modules(("bot/cogs", "bot.cogs"))
- self.cogs = set(ext for ext in modules if ext.name[-1] != "_")
-
@group(name='extensions', aliases=('c', 'ext', 'exts'), invoke_without_command=True)
async def extensions_group(self, ctx: Context) -> None:
"""Load, unload, reload, and list active cogs."""
@@ -291,9 +312,6 @@ class Extensions(Cog):
verb = action.name.lower()
error_msg = None
- if ext not in self.cogs:
- return f":x: Extension {ext} does not exist.", None
-
if (
(action is Action.LOAD and ext not in self.bot.extensions)
or (action is Action.UNLOAD and ext in self.bot.extensions)
@@ -321,6 +339,13 @@ class Extensions(Cog):
"""Only allow moderators and core developers to invoke the commands in this cog."""
return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer)
+ # This cannot be static (must have a __func__ attribute).
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Handle BadArgument errors locally to prevent the help command from showing."""
+ if isinstance(error, BadArgument):
+ await ctx.send(str(error))
+ error.handled = True
+
def setup(bot: Bot) -> None:
"""Load the Extensions cog."""