diff options
| -rw-r--r-- | bot/__main__.py | 2 | ||||
| -rw-r--r-- | bot/cogs/alias.py | 7 | ||||
| -rw-r--r-- | bot/cogs/cogs.py | 298 | ||||
| -rw-r--r-- | bot/cogs/extensions.py | 236 | 
4 files changed, 241 insertions, 302 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index d0924be78..19a7e5ec6 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -42,7 +42,7 @@ bot.load_extension("bot.cogs.security")  bot.load_extension("bot.cogs.antispam")  bot.load_extension("bot.cogs.bot")  bot.load_extension("bot.cogs.clean") -bot.load_extension("bot.cogs.cogs") +bot.load_extension("bot.cogs.extensions")  bot.load_extension("bot.cogs.help")  # Only load this in production diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 0f49a400c..6648805e9 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -5,6 +5,7 @@ from typing import Union  from discord import Colour, Embed, Member, User  from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group +from bot.cogs.extensions import Extension  from bot.cogs.watchchannels.watchchannel import proxy_user  from bot.converters import TagNameConverter  from bot.pagination import LinePaginator @@ -84,9 +85,9 @@ class Alias (Cog):          await self.invoke(ctx, "site rules")      @command(name="reload", hidden=True) -    async def cogs_reload_alias(self, ctx: Context, *, cog_name: str) -> None: -        """Alias for invoking <prefix>cogs reload [cog_name].""" -        await self.invoke(ctx, "cogs reload", cog_name) +    async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: +        """Alias for invoking <prefix>extensions reload [extensions...].""" +        await self.invoke(ctx, "extensions reload", *extensions)      @command(name="defon", hidden=True)      async def defcon_enable_alias(self, ctx: Context) -> None: diff --git a/bot/cogs/cogs.py b/bot/cogs/cogs.py deleted file mode 100644 index 1f6ccd09c..000000000 --- a/bot/cogs/cogs.py +++ /dev/null @@ -1,298 +0,0 @@ -import logging -import os - -from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group - -from bot.constants import ( -    Emojis, MODERATION_ROLES, Roles, URLs -) -from bot.decorators import with_role -from bot.pagination import LinePaginator - -log = logging.getLogger(__name__) - -KEEP_LOADED = ["bot.cogs.cogs", "bot.cogs.modlog"] - - -class Cogs(Cog): -    """Cog management commands.""" - -    def __init__(self, bot: Bot): -        self.bot = bot -        self.cogs = {} - -        # Load up the cog names -        log.info("Initializing cog names...") -        for filename in os.listdir("bot/cogs"): -            if filename.endswith(".py") and "_" not in filename: -                if os.path.isfile(f"bot/cogs/{filename}"): -                    cog = filename[:-3] - -                    self.cogs[cog] = f"bot.cogs.{cog}" - -        # Allow reverse lookups by reversing the pairs -        self.cogs.update({v: k for k, v in self.cogs.items()}) - -    @group(name='cogs', aliases=('c',), invoke_without_command=True) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def cogs_group(self, ctx: Context) -> None: -        """Load, unload, reload, and list active cogs.""" -        await ctx.invoke(self.bot.get_command("help"), "cogs") - -    @cogs_group.command(name='load', aliases=('l',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def load_command(self, ctx: Context, cog: str) -> None: -        """ -        Load up an unloaded cog, given the module containing it. - -        You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the -        entire module directly. -        """ -        cog = cog.lower() - -        embed = Embed() -        embed.colour = Colour.red() - -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        if cog in self.cogs: -            full_cog = self.cogs[cog] -        elif "." in cog: -            full_cog = cog -        else: -            full_cog = None -            log.warning(f"{ctx.author} requested we load the '{cog}' cog, but that cog doesn't exist.") -            embed.description = f"Unknown cog: {cog}" - -        if full_cog: -            if full_cog not in self.bot.extensions: -                try: -                    self.bot.load_extension(full_cog) -                except ImportError: -                    log.exception(f"{ctx.author} requested we load the '{cog}' cog, " -                                  f"but the cog module {full_cog} could not be found!") -                    embed.description = f"Invalid cog: {cog}\n\nCould not find cog module {full_cog}" -                except Exception as e: -                    log.exception(f"{ctx.author} requested we load the '{cog}' cog, " -                                  "but the loading failed") -                    embed.description = f"Failed to load cog: {cog}\n\n{e.__class__.__name__}: {e}" -                else: -                    log.debug(f"{ctx.author} requested we load the '{cog}' cog. Cog loaded!") -                    embed.description = f"Cog loaded: {cog}" -                    embed.colour = Colour.green() -            else: -                log.warning(f"{ctx.author} requested we load the '{cog}' cog, but the cog was already loaded!") -                embed.description = f"Cog {cog} is already loaded" - -        await ctx.send(embed=embed) - -    @cogs_group.command(name='unload', aliases=('ul',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def unload_command(self, ctx: Context, cog: str) -> None: -        """ -        Unload an already-loaded cog, given the module containing it. - -        You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the -        entire module directly. -        """ -        cog = cog.lower() - -        embed = Embed() -        embed.colour = Colour.red() - -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        if cog in self.cogs: -            full_cog = self.cogs[cog] -        elif "." in cog: -            full_cog = cog -        else: -            full_cog = None -            log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but that cog doesn't exist.") -            embed.description = f"Unknown cog: {cog}" - -        if full_cog: -            if full_cog in KEEP_LOADED: -                log.warning(f"{ctx.author} requested we unload `{full_cog}`, that sneaky pete. We said no.") -                embed.description = f"You may not unload `{full_cog}`!" -            elif full_cog in self.bot.extensions: -                try: -                    self.bot.unload_extension(full_cog) -                except Exception as e: -                    log.exception(f"{ctx.author} requested we unload the '{cog}' cog, " -                                  "but the unloading failed") -                    embed.description = f"Failed to unload cog: {cog}\n\n```{e}```" -                else: -                    log.debug(f"{ctx.author} requested we unload the '{cog}' cog. Cog unloaded!") -                    embed.description = f"Cog unloaded: {cog}" -                    embed.colour = Colour.green() -            else: -                log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but the cog wasn't loaded!") -                embed.description = f"Cog {cog} is not loaded" - -        await ctx.send(embed=embed) - -    @cogs_group.command(name='reload', aliases=('r',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def reload_command(self, ctx: Context, cog: str) -> None: -        """ -        Reload an unloaded cog, given the module containing it. - -        You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the -        entire module directly. - -        If you specify "*" as the cog, every cog currently loaded will be unloaded, and then every cog present in the -        bot/cogs directory will be loaded. -        """ -        cog = cog.lower() - -        embed = Embed() -        embed.colour = Colour.red() - -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        if cog == "*": -            full_cog = cog -        elif cog in self.cogs: -            full_cog = self.cogs[cog] -        elif "." in cog: -            full_cog = cog -        else: -            full_cog = None -            log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but that cog doesn't exist.") -            embed.description = f"Unknown cog: {cog}" - -        if full_cog: -            if full_cog == "*": -                all_cogs = [ -                    f"bot.cogs.{fn[:-3]}" for fn in os.listdir("bot/cogs") -                    if os.path.isfile(f"bot/cogs/{fn}") and fn.endswith(".py") and "_" not in fn -                ] - -                failed_unloads = {} -                failed_loads = {} - -                unloaded = 0 -                loaded = 0 - -                for loaded_cog in self.bot.extensions.copy().keys(): -                    try: -                        self.bot.unload_extension(loaded_cog) -                    except Exception as e: -                        failed_unloads[loaded_cog] = f"{e.__class__.__name__}: {e}" -                    else: -                        unloaded += 1 - -                for unloaded_cog in all_cogs: -                    try: -                        self.bot.load_extension(unloaded_cog) -                    except Exception as e: -                        failed_loads[unloaded_cog] = f"{e.__class__.__name__}: {e}" -                    else: -                        loaded += 1 - -                lines = [ -                    "**All cogs reloaded**", -                    f"**Unloaded**: {unloaded} / **Loaded**: {loaded}" -                ] - -                if failed_unloads: -                    lines.append("\n**Unload failures**") - -                    for cog, error in failed_unloads: -                        lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`") - -                if failed_loads: -                    lines.append("\n**Load failures**") - -                    for cog, error in failed_loads.items(): -                        lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`") - -                log.debug(f"{ctx.author} requested we reload all cogs. Here are the results: \n" -                          f"{lines}") - -                await LinePaginator.paginate(lines, ctx, embed, empty=False) -                return - -            elif full_cog in self.bot.extensions: -                try: -                    self.bot.unload_extension(full_cog) -                    self.bot.load_extension(full_cog) -                except Exception as e: -                    log.exception(f"{ctx.author} requested we reload the '{cog}' cog, " -                                  "but the unloading failed") -                    embed.description = f"Failed to reload cog: {cog}\n\n```{e}```" -                else: -                    log.debug(f"{ctx.author} requested we reload the '{cog}' cog. Cog reloaded!") -                    embed.description = f"Cog reload: {cog}" -                    embed.colour = Colour.green() -            else: -                log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but the cog wasn't loaded!") -                embed.description = f"Cog {cog} is not loaded" - -        await ctx.send(embed=embed) - -    @cogs_group.command(name='list', aliases=('all',)) -    @with_role(*MODERATION_ROLES, Roles.core_developer) -    async def list_command(self, ctx: Context) -> None: -        """ -        Get a list of all cogs, including their loaded status. - -        Gray indicates that the cog is unloaded. Green indicates that the cog is currently loaded. -        """ -        embed = Embed() -        lines = [] -        cogs = {} - -        embed.colour = Colour.blurple() -        embed.set_author( -            name="Python Bot (Cogs)", -            url=URLs.github_bot_repo, -            icon_url=URLs.bot_avatar -        ) - -        for key, _value in self.cogs.items(): -            if "." not in key: -                continue - -            if key in self.bot.extensions: -                cogs[key] = True -            else: -                cogs[key] = False - -        for key in self.bot.extensions.keys(): -            if key not in self.cogs: -                cogs[key] = True - -        for cog, loaded in sorted(cogs.items(), key=lambda x: x[0]): -            if cog in self.cogs: -                cog = self.cogs[cog] - -            if loaded: -                status = Emojis.status_online -            else: -                status = Emojis.status_offline - -            lines.append(f"{status}  {cog}") - -        log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") -        await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) - - -def setup(bot: Bot) -> None: -    """Cogs cog load.""" -    bot.add_cog(Cogs(bot)) -    log.info("Cog loaded: Cogs") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py new file mode 100644 index 000000000..bb66e0b8e --- /dev/null +++ b/bot/cogs/extensions.py @@ -0,0 +1,236 @@ +import functools +import logging +import typing as t +from enum import Enum +from pkgutil import iter_modules + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Bot, Context, group + +from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + +UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"} +EXTENSIONS = frozenset( +    ext.name +    for ext in iter_modules(("bot/cogs",), "bot.cogs.") +    if ext.name[-1] != "_" +) + + +class Action(Enum): +    """Represents an action to perform on an extension.""" + +    # Need to be partial otherwise they are considered to be function definitions. +    LOAD = functools.partial(Bot.load_extension) +    UNLOAD = functools.partial(Bot.unload_extension) +    RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if "." not in argument: +            argument = f"bot.cogs.{argument}" + +        if argument in EXTENSIONS: +            return argument +        else: +            raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): +    """Extension management commands.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) +    async def extensions_group(self, ctx: Context) -> None: +        """Load, unload, reload, and list loaded extensions.""" +        await ctx.invoke(self.bot.get_command("help"), "extensions") + +    @extensions_group.command(name="load", aliases=("l",)) +    async def load_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Load extensions given their fully qualified or unqualified names. + +        If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.invoke(self.bot.get_command("help"), "extensions load") +            return + +        if "*" in extensions or "**" in extensions: +            extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + +        msg = self.batch_manage(Action.LOAD, *extensions) +        await ctx.send(msg) + +    @extensions_group.command(name="unload", aliases=("ul",)) +    async def unload_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Unload currently loaded extensions given their fully qualified or unqualified names. + +        If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.invoke(self.bot.get_command("help"), "extensions unload") +            return + +        blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + +        if blacklisted: +            msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" +        else: +            if "*" in extensions or "**" in extensions: +                extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + +            msg = self.batch_manage(Action.UNLOAD, *extensions) + +        await ctx.send(msg) + +    @extensions_group.command(name="reload", aliases=("r",)) +    async def reload_command(self, ctx: Context, *extensions: Extension) -> None: +        """ +        Reload extensions given their fully qualified or unqualified names. + +        If an extension fails to be reloaded, it will be rolled-back to the prior working state. + +        If '\*' is given as the name, all currently loaded extensions will be reloaded. +        If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.invoke(self.bot.get_command("help"), "extensions reload") +            return + +        if "**" in extensions: +            extensions = EXTENSIONS +        elif "*" in extensions: +            extensions = set(self.bot.extensions.keys()) | set(extensions) +            extensions.remove("*") + +        msg = self.batch_manage(Action.RELOAD, *extensions) + +        await ctx.send(msg) + +    @extensions_group.command(name="list", aliases=("all",)) +    async def list_command(self, ctx: Context) -> None: +        """ +        Get a list of all extensions, including their loaded status. + +        Grey indicates that the extension is unloaded. +        Green indicates that the extension is currently loaded. +        """ +        embed = Embed() +        lines = [] + +        embed.colour = Colour.blurple() +        embed.set_author( +            name="Extensions List", +            url=URLs.github_bot_repo, +            icon_url=URLs.bot_avatar +        ) + +        for ext in sorted(list(EXTENSIONS)): +            if ext in self.bot.extensions: +                status = Emojis.status_online +            else: +                status = Emojis.status_offline + +            ext = ext.rsplit(".", 1)[1] +            lines.append(f"{status}  {ext}") + +        log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") +        await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False) + +    def batch_manage(self, action: Action, *extensions: str) -> str: +        """ +        Apply an action to multiple extensions and return a message with the results. + +        If only one extension is given, it is deferred to `manage()`. +        """ +        if len(extensions) == 1: +            msg, _ = self.manage(action, extensions[0]) +            return msg + +        verb = action.name.lower() +        failures = {} + +        for extension in extensions: +            _, error = self.manage(action, extension) +            if error: +                failures[extension] = error + +        emoji = ":x:" if failures else ":ok_hand:" +        msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + +        if failures: +            failures = "\n".join(f"{ext}\n    {err}" for ext, err in failures.items()) +            msg += f"\nFailures:```{failures}```" + +        log.debug(f"Batch {verb}ed extensions.") + +        return msg + +    def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: +        """Apply an action to an extension and return the status message and any error message.""" +        verb = action.name.lower() +        error_msg = None + +        try: +            action.value(self.bot, ext) +        except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): +            if action is Action.RELOAD: +                # When reloading, just load the extension if it was not loaded. +                return self.manage(Action.LOAD, ext) + +            msg = f":x: Extension `{ext}` is already {verb}ed." +            log.debug(msg[4:]) +        except Exception as e: +            if hasattr(e, "original"): +                e = e.original + +            log.exception(f"Extension '{ext}' failed to {verb}.") + +            error_msg = f"{e.__class__.__name__}: {e}" +            msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" +        else: +            msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." +            log.debug(msg[10:]) + +        return msg, error_msg + +    # This cannot be static (must have a __func__ attribute). +    def cog_check(self, ctx: Context) -> bool: +        """Only allow moderators and core developers to invoke the commands in this cog.""" +        return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer) + +    # This cannot be static (must have a __func__ attribute). +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Handle BadArgument errors locally to prevent the help command from showing.""" +        if isinstance(error, commands.BadArgument): +            await ctx.send(str(error)) +            error.handled = True + + +def setup(bot: Bot) -> None: +    """Load the Extensions cog.""" +    bot.add_cog(Extensions(bot)) +    log.info("Cog loaded: Extensions") | 
