aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/__main__.py2
-rw-r--r--bot/cogs/alias.py7
-rw-r--r--bot/cogs/cogs.py298
-rw-r--r--bot/cogs/extensions.py236
4 files changed, 241 insertions, 302 deletions
diff --git a/bot/__main__.py b/bot/__main__.py
index d0924be78..19a7e5ec6 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -42,7 +42,7 @@ bot.load_extension("bot.cogs.security")
bot.load_extension("bot.cogs.antispam")
bot.load_extension("bot.cogs.bot")
bot.load_extension("bot.cogs.clean")
-bot.load_extension("bot.cogs.cogs")
+bot.load_extension("bot.cogs.extensions")
bot.load_extension("bot.cogs.help")
# Only load this in production
diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py
index 0f49a400c..6648805e9 100644
--- a/bot/cogs/alias.py
+++ b/bot/cogs/alias.py
@@ -5,6 +5,7 @@ from typing import Union
from discord import Colour, Embed, Member, User
from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group
+from bot.cogs.extensions import Extension
from bot.cogs.watchchannels.watchchannel import proxy_user
from bot.converters import TagNameConverter
from bot.pagination import LinePaginator
@@ -84,9 +85,9 @@ class Alias (Cog):
await self.invoke(ctx, "site rules")
@command(name="reload", hidden=True)
- async def cogs_reload_alias(self, ctx: Context, *, cog_name: str) -> None:
- """Alias for invoking <prefix>cogs reload [cog_name]."""
- await self.invoke(ctx, "cogs reload", cog_name)
+ async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None:
+ """Alias for invoking <prefix>extensions reload [extensions...]."""
+ await self.invoke(ctx, "extensions reload", *extensions)
@command(name="defon", hidden=True)
async def defcon_enable_alias(self, ctx: Context) -> None:
diff --git a/bot/cogs/cogs.py b/bot/cogs/cogs.py
deleted file mode 100644
index 1f6ccd09c..000000000
--- a/bot/cogs/cogs.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import logging
-import os
-
-from discord import Colour, Embed
-from discord.ext.commands import Bot, Cog, Context, group
-
-from bot.constants import (
- Emojis, MODERATION_ROLES, Roles, URLs
-)
-from bot.decorators import with_role
-from bot.pagination import LinePaginator
-
-log = logging.getLogger(__name__)
-
-KEEP_LOADED = ["bot.cogs.cogs", "bot.cogs.modlog"]
-
-
-class Cogs(Cog):
- """Cog management commands."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
- self.cogs = {}
-
- # Load up the cog names
- log.info("Initializing cog names...")
- for filename in os.listdir("bot/cogs"):
- if filename.endswith(".py") and "_" not in filename:
- if os.path.isfile(f"bot/cogs/{filename}"):
- cog = filename[:-3]
-
- self.cogs[cog] = f"bot.cogs.{cog}"
-
- # Allow reverse lookups by reversing the pairs
- self.cogs.update({v: k for k, v in self.cogs.items()})
-
- @group(name='cogs', aliases=('c',), invoke_without_command=True)
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def cogs_group(self, ctx: Context) -> None:
- """Load, unload, reload, and list active cogs."""
- await ctx.invoke(self.bot.get_command("help"), "cogs")
-
- @cogs_group.command(name='load', aliases=('l',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def load_command(self, ctx: Context, cog: str) -> None:
- """
- Load up an unloaded cog, given the module containing it.
-
- You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the
- entire module directly.
- """
- cog = cog.lower()
-
- embed = Embed()
- embed.colour = Colour.red()
-
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- if cog in self.cogs:
- full_cog = self.cogs[cog]
- elif "." in cog:
- full_cog = cog
- else:
- full_cog = None
- log.warning(f"{ctx.author} requested we load the '{cog}' cog, but that cog doesn't exist.")
- embed.description = f"Unknown cog: {cog}"
-
- if full_cog:
- if full_cog not in self.bot.extensions:
- try:
- self.bot.load_extension(full_cog)
- except ImportError:
- log.exception(f"{ctx.author} requested we load the '{cog}' cog, "
- f"but the cog module {full_cog} could not be found!")
- embed.description = f"Invalid cog: {cog}\n\nCould not find cog module {full_cog}"
- except Exception as e:
- log.exception(f"{ctx.author} requested we load the '{cog}' cog, "
- "but the loading failed")
- embed.description = f"Failed to load cog: {cog}\n\n{e.__class__.__name__}: {e}"
- else:
- log.debug(f"{ctx.author} requested we load the '{cog}' cog. Cog loaded!")
- embed.description = f"Cog loaded: {cog}"
- embed.colour = Colour.green()
- else:
- log.warning(f"{ctx.author} requested we load the '{cog}' cog, but the cog was already loaded!")
- embed.description = f"Cog {cog} is already loaded"
-
- await ctx.send(embed=embed)
-
- @cogs_group.command(name='unload', aliases=('ul',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def unload_command(self, ctx: Context, cog: str) -> None:
- """
- Unload an already-loaded cog, given the module containing it.
-
- You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the
- entire module directly.
- """
- cog = cog.lower()
-
- embed = Embed()
- embed.colour = Colour.red()
-
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- if cog in self.cogs:
- full_cog = self.cogs[cog]
- elif "." in cog:
- full_cog = cog
- else:
- full_cog = None
- log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but that cog doesn't exist.")
- embed.description = f"Unknown cog: {cog}"
-
- if full_cog:
- if full_cog in KEEP_LOADED:
- log.warning(f"{ctx.author} requested we unload `{full_cog}`, that sneaky pete. We said no.")
- embed.description = f"You may not unload `{full_cog}`!"
- elif full_cog in self.bot.extensions:
- try:
- self.bot.unload_extension(full_cog)
- except Exception as e:
- log.exception(f"{ctx.author} requested we unload the '{cog}' cog, "
- "but the unloading failed")
- embed.description = f"Failed to unload cog: {cog}\n\n```{e}```"
- else:
- log.debug(f"{ctx.author} requested we unload the '{cog}' cog. Cog unloaded!")
- embed.description = f"Cog unloaded: {cog}"
- embed.colour = Colour.green()
- else:
- log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but the cog wasn't loaded!")
- embed.description = f"Cog {cog} is not loaded"
-
- await ctx.send(embed=embed)
-
- @cogs_group.command(name='reload', aliases=('r',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def reload_command(self, ctx: Context, cog: str) -> None:
- """
- Reload an unloaded cog, given the module containing it.
-
- You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the
- entire module directly.
-
- If you specify "*" as the cog, every cog currently loaded will be unloaded, and then every cog present in the
- bot/cogs directory will be loaded.
- """
- cog = cog.lower()
-
- embed = Embed()
- embed.colour = Colour.red()
-
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- if cog == "*":
- full_cog = cog
- elif cog in self.cogs:
- full_cog = self.cogs[cog]
- elif "." in cog:
- full_cog = cog
- else:
- full_cog = None
- log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but that cog doesn't exist.")
- embed.description = f"Unknown cog: {cog}"
-
- if full_cog:
- if full_cog == "*":
- all_cogs = [
- f"bot.cogs.{fn[:-3]}" for fn in os.listdir("bot/cogs")
- if os.path.isfile(f"bot/cogs/{fn}") and fn.endswith(".py") and "_" not in fn
- ]
-
- failed_unloads = {}
- failed_loads = {}
-
- unloaded = 0
- loaded = 0
-
- for loaded_cog in self.bot.extensions.copy().keys():
- try:
- self.bot.unload_extension(loaded_cog)
- except Exception as e:
- failed_unloads[loaded_cog] = f"{e.__class__.__name__}: {e}"
- else:
- unloaded += 1
-
- for unloaded_cog in all_cogs:
- try:
- self.bot.load_extension(unloaded_cog)
- except Exception as e:
- failed_loads[unloaded_cog] = f"{e.__class__.__name__}: {e}"
- else:
- loaded += 1
-
- lines = [
- "**All cogs reloaded**",
- f"**Unloaded**: {unloaded} / **Loaded**: {loaded}"
- ]
-
- if failed_unloads:
- lines.append("\n**Unload failures**")
-
- for cog, error in failed_unloads:
- lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`")
-
- if failed_loads:
- lines.append("\n**Load failures**")
-
- for cog, error in failed_loads.items():
- lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`")
-
- log.debug(f"{ctx.author} requested we reload all cogs. Here are the results: \n"
- f"{lines}")
-
- await LinePaginator.paginate(lines, ctx, embed, empty=False)
- return
-
- elif full_cog in self.bot.extensions:
- try:
- self.bot.unload_extension(full_cog)
- self.bot.load_extension(full_cog)
- except Exception as e:
- log.exception(f"{ctx.author} requested we reload the '{cog}' cog, "
- "but the unloading failed")
- embed.description = f"Failed to reload cog: {cog}\n\n```{e}```"
- else:
- log.debug(f"{ctx.author} requested we reload the '{cog}' cog. Cog reloaded!")
- embed.description = f"Cog reload: {cog}"
- embed.colour = Colour.green()
- else:
- log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but the cog wasn't loaded!")
- embed.description = f"Cog {cog} is not loaded"
-
- await ctx.send(embed=embed)
-
- @cogs_group.command(name='list', aliases=('all',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def list_command(self, ctx: Context) -> None:
- """
- Get a list of all cogs, including their loaded status.
-
- Gray indicates that the cog is unloaded. Green indicates that the cog is currently loaded.
- """
- embed = Embed()
- lines = []
- cogs = {}
-
- embed.colour = Colour.blurple()
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- for key, _value in self.cogs.items():
- if "." not in key:
- continue
-
- if key in self.bot.extensions:
- cogs[key] = True
- else:
- cogs[key] = False
-
- for key in self.bot.extensions.keys():
- if key not in self.cogs:
- cogs[key] = True
-
- for cog, loaded in sorted(cogs.items(), key=lambda x: x[0]):
- if cog in self.cogs:
- cog = self.cogs[cog]
-
- if loaded:
- status = Emojis.status_online
- else:
- status = Emojis.status_offline
-
- lines.append(f"{status} {cog}")
-
- log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
- await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False)
-
-
-def setup(bot: Bot) -> None:
- """Cogs cog load."""
- bot.add_cog(Cogs(bot))
- log.info("Cog loaded: Cogs")
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
new file mode 100644
index 000000000..bb66e0b8e
--- /dev/null
+++ b/bot/cogs/extensions.py
@@ -0,0 +1,236 @@
+import functools
+import logging
+import typing as t
+from enum import Enum
+from pkgutil import iter_modules
+
+from discord import Colour, Embed
+from discord.ext import commands
+from discord.ext.commands import Bot, Context, group
+
+from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
+from bot.pagination import LinePaginator
+from bot.utils.checks import with_role_check
+
+log = logging.getLogger(__name__)
+
+UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"}
+EXTENSIONS = frozenset(
+ ext.name
+ for ext in iter_modules(("bot/cogs",), "bot.cogs.")
+ if ext.name[-1] != "_"
+)
+
+
+class Action(Enum):
+ """Represents an action to perform on an extension."""
+
+ # Need to be partial otherwise they are considered to be function definitions.
+ LOAD = functools.partial(Bot.load_extension)
+ UNLOAD = functools.partial(Bot.unload_extension)
+ RELOAD = functools.partial(Bot.reload_extension)
+
+
+class Extension(commands.Converter):
+ """
+ Fully qualify the name of an extension and ensure it exists.
+
+ The * and ** values bypass this when used with the reload command.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Fully qualify the name of an extension and ensure it exists."""
+ # Special values to reload all extensions
+ if argument == "*" or argument == "**":
+ return argument
+
+ argument = argument.lower()
+
+ if "." not in argument:
+ argument = f"bot.cogs.{argument}"
+
+ if argument in EXTENSIONS:
+ return argument
+ else:
+ raise commands.BadArgument(f":x: Could not find the extension `{argument}`.")
+
+
+class Extensions(commands.Cog):
+ """Extension management commands."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+
+ @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True)
+ async def extensions_group(self, ctx: Context) -> None:
+ """Load, unload, reload, and list loaded extensions."""
+ await ctx.invoke(self.bot.get_command("help"), "extensions")
+
+ @extensions_group.command(name="load", aliases=("l",))
+ async def load_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Load extensions given their fully qualified or unqualified names.
+
+ If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.invoke(self.bot.get_command("help"), "extensions load")
+ return
+
+ if "*" in extensions or "**" in extensions:
+ extensions = set(EXTENSIONS) - set(self.bot.extensions.keys())
+
+ msg = self.batch_manage(Action.LOAD, *extensions)
+ await ctx.send(msg)
+
+ @extensions_group.command(name="unload", aliases=("ul",))
+ async def unload_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Unload currently loaded extensions given their fully qualified or unqualified names.
+
+ If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.invoke(self.bot.get_command("help"), "extensions unload")
+ return
+
+ blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions))
+
+ if blacklisted:
+ msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```"
+ else:
+ if "*" in extensions or "**" in extensions:
+ extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST
+
+ msg = self.batch_manage(Action.UNLOAD, *extensions)
+
+ await ctx.send(msg)
+
+ @extensions_group.command(name="reload", aliases=("r",))
+ async def reload_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Reload extensions given their fully qualified or unqualified names.
+
+ If an extension fails to be reloaded, it will be rolled-back to the prior working state.
+
+ If '\*' is given as the name, all currently loaded extensions will be reloaded.
+ If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.invoke(self.bot.get_command("help"), "extensions reload")
+ return
+
+ if "**" in extensions:
+ extensions = EXTENSIONS
+ elif "*" in extensions:
+ extensions = set(self.bot.extensions.keys()) | set(extensions)
+ extensions.remove("*")
+
+ msg = self.batch_manage(Action.RELOAD, *extensions)
+
+ await ctx.send(msg)
+
+ @extensions_group.command(name="list", aliases=("all",))
+ async def list_command(self, ctx: Context) -> None:
+ """
+ Get a list of all extensions, including their loaded status.
+
+ Grey indicates that the extension is unloaded.
+ Green indicates that the extension is currently loaded.
+ """
+ embed = Embed()
+ lines = []
+
+ embed.colour = Colour.blurple()
+ embed.set_author(
+ name="Extensions List",
+ url=URLs.github_bot_repo,
+ icon_url=URLs.bot_avatar
+ )
+
+ for ext in sorted(list(EXTENSIONS)):
+ if ext in self.bot.extensions:
+ status = Emojis.status_online
+ else:
+ status = Emojis.status_offline
+
+ ext = ext.rsplit(".", 1)[1]
+ lines.append(f"{status} {ext}")
+
+ log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
+ await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False)
+
+ def batch_manage(self, action: Action, *extensions: str) -> str:
+ """
+ Apply an action to multiple extensions and return a message with the results.
+
+ If only one extension is given, it is deferred to `manage()`.
+ """
+ if len(extensions) == 1:
+ msg, _ = self.manage(action, extensions[0])
+ return msg
+
+ verb = action.name.lower()
+ failures = {}
+
+ for extension in extensions:
+ _, error = self.manage(action, extension)
+ if error:
+ failures[extension] = error
+
+ emoji = ":x:" if failures else ":ok_hand:"
+ msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed."
+
+ if failures:
+ failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items())
+ msg += f"\nFailures:```{failures}```"
+
+ log.debug(f"Batch {verb}ed extensions.")
+
+ return msg
+
+ def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:
+ """Apply an action to an extension and return the status message and any error message."""
+ verb = action.name.lower()
+ error_msg = None
+
+ try:
+ action.value(self.bot, ext)
+ except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):
+ if action is Action.RELOAD:
+ # When reloading, just load the extension if it was not loaded.
+ return self.manage(Action.LOAD, ext)
+
+ msg = f":x: Extension `{ext}` is already {verb}ed."
+ log.debug(msg[4:])
+ except Exception as e:
+ if hasattr(e, "original"):
+ e = e.original
+
+ log.exception(f"Extension '{ext}' failed to {verb}.")
+
+ error_msg = f"{e.__class__.__name__}: {e}"
+ msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```"
+ else:
+ msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`."
+ log.debug(msg[10:])
+
+ return msg, error_msg
+
+ # This cannot be static (must have a __func__ attribute).
+ def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators and core developers to invoke the commands in this cog."""
+ return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer)
+
+ # This cannot be static (must have a __func__ attribute).
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Handle BadArgument errors locally to prevent the help command from showing."""
+ if isinstance(error, commands.BadArgument):
+ await ctx.send(str(error))
+ error.handled = True
+
+
+def setup(bot: Bot) -> None:
+ """Load the Extensions cog."""
+ bot.add_cog(Extensions(bot))
+ log.info("Cog loaded: Extensions")