diff options
| -rw-r--r-- | bot/__main__.py | 3 | ||||
| -rw-r--r-- | bot/constants.py | 7 | ||||
| -rw-r--r-- | bot/exts/__init__.py | 23 | ||||
| -rw-r--r-- | bot/exts/evergreen/snakes/__init__.py | 2 | ||||
| -rw-r--r-- | bot/exts/evergreen/snakes/_converter.py (renamed from bot/exts/evergreen/snakes/converter.py) | 2 | ||||
| -rw-r--r-- | bot/exts/evergreen/snakes/_snakes_cog.py (renamed from bot/exts/evergreen/snakes/snakes_cog.py) | 4 | ||||
| -rw-r--r-- | bot/exts/evergreen/snakes/_utils.py (renamed from bot/exts/evergreen/snakes/utils.py) | 0 | ||||
| -rw-r--r-- | bot/exts/utils/__init__.py | 0 | ||||
| -rw-r--r-- | bot/exts/utils/extensions.py | 265 | ||||
| -rw-r--r-- | bot/utils/checks.py | 164 | ||||
| -rw-r--r-- | bot/utils/extensions.py | 34 | 
11 files changed, 477 insertions, 27 deletions
| diff --git a/bot/__main__.py b/bot/__main__.py index 0ffd6143..cd2d43a9 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -5,8 +5,9 @@ from sentry_sdk.integrations.logging import LoggingIntegration  from bot.bot import bot  from bot.constants import Client, STAFF_ROLES, WHITELISTED_CHANNELS -from bot.exts import walk_extensions  from bot.utils.decorators import in_channel_check +from bot.utils.extensions import walk_extensions +  sentry_logging = LoggingIntegration(      level=logging.DEBUG, diff --git a/bot/constants.py b/bot/constants.py index 7c8f72cb..935b90e0 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -84,6 +84,7 @@ class Client(NamedTuple):      token = environ.get("SEASONALBOT_TOKEN")      sentry_dsn = environ.get("SEASONALBOT_SENTRY_DSN")      debug = environ.get("SEASONALBOT_DEBUG", "").lower() == "true" +    github_bot_repo = "https://github.com/python-discord/seasonalbot"      # Override seasonal locks: 1 (January) to 12 (December)      month_override = int(environ["MONTH_OVERRIDE"]) if "MONTH_OVERRIDE" in environ else None @@ -122,6 +123,11 @@ class Emojis:      pull_request_closed = "<:PRClosed:629695470519713818>"      merge = "<:PRMerged:629695470570176522>" +    status_online = "<:status_online:470326272351010816>" +    status_idle = "<:status_idle:470326266625785866>" +    status_dnd = "<:status_dnd:470326272082313216>" +    status_offline = "<:status_offline:470326266537705472>" +  class Hacktoberfest(NamedTuple):      voice_id = 514420006474219521 @@ -177,6 +183,7 @@ class Roles(NamedTuple):      verified = 352427296948486144      helpers = 267630620367257601      rockstars = 458226413825294336 +    core_developers = 587606783669829632  class Tokens(NamedTuple): diff --git a/bot/exts/__init__.py b/bot/exts/__init__.py index 25deb9af..13f484ac 100644 --- a/bot/exts/__init__.py +++ b/bot/exts/__init__.py @@ -1,9 +1,8 @@  import logging  import pkgutil -from pathlib import Path  from typing import Iterator -__all__ = ("get_package_names", "walk_extensions") +__all__ = ("get_package_names",)  log = logging.getLogger(__name__) @@ -13,23 +12,3 @@ def get_package_names() -> Iterator[str]:      for package in pkgutil.iter_modules(__path__):          if package.ispkg:              yield package.name - - -def walk_extensions() -> Iterator[str]: -    """ -    Iterate dot-separated paths to all extensions. - -    The strings are formatted in a way such that the bot's `load_extension` -    method can take them. Use this to load all available extensions. - -    This intentionally doesn't make use of pkgutil's `walk_packages`, as we only -    want to build paths to extensions - not recursively all modules. For some -    extensions, the `setup` function is in the package's __init__ file, while -    modules nested under the package are only helpers. Constructing the paths -    ourselves serves our purpose better. -    """ -    base_path = Path(__path__[0]) - -    for package in get_package_names(): -        for extension in pkgutil.iter_modules([base_path.joinpath(package)]): -            yield f"bot.exts.{package}.{extension.name}" diff --git a/bot/exts/evergreen/snakes/__init__.py b/bot/exts/evergreen/snakes/__init__.py index 2eae2751..bc42f0c2 100644 --- a/bot/exts/evergreen/snakes/__init__.py +++ b/bot/exts/evergreen/snakes/__init__.py @@ -2,7 +2,7 @@ import logging  from discord.ext import commands -from bot.exts.evergreen.snakes.snakes_cog import Snakes +from bot.exts.evergreen.snakes._snakes_cog import Snakes  log = logging.getLogger(__name__) diff --git a/bot/exts/evergreen/snakes/converter.py b/bot/exts/evergreen/snakes/_converter.py index 55609b8e..eee248cf 100644 --- a/bot/exts/evergreen/snakes/converter.py +++ b/bot/exts/evergreen/snakes/_converter.py @@ -7,7 +7,7 @@ import discord  from discord.ext.commands import Context, Converter  from fuzzywuzzy import fuzz -from bot.exts.evergreen.snakes.utils import SNAKE_RESOURCES +from bot.exts.evergreen.snakes._utils import SNAKE_RESOURCES  from bot.utils import disambiguate  log = logging.getLogger(__name__) diff --git a/bot/exts/evergreen/snakes/snakes_cog.py b/bot/exts/evergreen/snakes/_snakes_cog.py index 9bbad9fe..a846274b 100644 --- a/bot/exts/evergreen/snakes/snakes_cog.py +++ b/bot/exts/evergreen/snakes/_snakes_cog.py @@ -18,8 +18,8 @@ from discord import Colour, Embed, File, Member, Message, Reaction  from discord.ext.commands import BadArgument, Bot, Cog, CommandError, Context, bot_has_permissions, group  from bot.constants import ERROR_REPLIES, Tokens -from bot.exts.evergreen.snakes import utils -from bot.exts.evergreen.snakes.converter import Snake +from bot.exts.evergreen.snakes import _utils as utils +from bot.exts.evergreen.snakes._converter import Snake  from bot.utils.decorators import locked  log = logging.getLogger(__name__) diff --git a/bot/exts/evergreen/snakes/utils.py b/bot/exts/evergreen/snakes/_utils.py index 7d6caf04..7d6caf04 100644 --- a/bot/exts/evergreen/snakes/utils.py +++ b/bot/exts/evergreen/snakes/_utils.py diff --git a/bot/exts/utils/__init__.py b/bot/exts/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/bot/exts/utils/__init__.py diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py new file mode 100644 index 00000000..102a0416 --- /dev/null +++ b/bot/exts/utils/extensions.py @@ -0,0 +1,265 @@ +import functools +import logging +import typing as t +from enum import Enum + +from discord import Colour, Embed +from discord.ext import commands +from discord.ext.commands import Context, group + +from bot import exts +from bot.bot import SeasonalBot as Bot +from bot.constants import Client, Emojis, MODERATION_ROLES, Roles +from bot.utils.checks import with_role_check +from bot.utils.extensions import EXTENSIONS, unqualify +from bot.utils.pagination import LinePaginator + +log = logging.getLogger(__name__) + + +UNLOAD_BLACKLIST = {f"{exts.__name__}.utils.extensions"} +BASE_PATH_LEN = len(exts.__name__.split(".")) + + +class Action(Enum): +    """Represents an action to perform on an extension.""" + +    # Need to be partial otherwise they are considered to be function definitions. +    LOAD = functools.partial(Bot.load_extension) +    UNLOAD = functools.partial(Bot.unload_extension) +    RELOAD = functools.partial(Bot.reload_extension) + + +class Extension(commands.Converter): +    """ +    Fully qualify the name of an extension and ensure it exists. + +    The * and ** values bypass this when used with the reload command. +    """ + +    async def convert(self, ctx: Context, argument: str) -> str: +        """Fully qualify the name of an extension and ensure it exists.""" +        # Special values to reload all extensions +        if argument == "*" or argument == "**": +            return argument + +        argument = argument.lower() + +        if argument in EXTENSIONS: +            return argument +        elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS: +            return qualified_arg + +        matches = [] +        for ext in EXTENSIONS: +            if argument == unqualify(ext): +                matches.append(ext) + +        if len(matches) > 1: +            matches.sort() +            names = "\n".join(matches) +            raise commands.BadArgument( +                f":x: `{argument}` is an ambiguous extension name. " +                f"Please use one of the following fully-qualified names.```\n{names}```" +            ) +        elif matches: +            return matches[0] +        else: +            raise commands.BadArgument(f":x: Could not find the extension `{argument}`.") + + +class Extensions(commands.Cog): +    """Extension management commands.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True) +    async def extensions_group(self, ctx: Context) -> None: +        """Load, unload, reload, and list loaded extensions.""" +        await ctx.send_help(ctx.command) + +    @extensions_group.command(name="load", aliases=("l",)) +    async def load_command(self, ctx: Context, *extensions: Extension) -> None: +        r""" +        Load extensions given their fully qualified or unqualified names. + +        If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.send_help(ctx.command) +            return + +        if "*" in extensions or "**" in extensions: +            extensions = set(EXTENSIONS) - set(self.bot.extensions.keys()) + +        msg = self.batch_manage(Action.LOAD, *extensions) +        await ctx.send(msg) + +    @extensions_group.command(name="unload", aliases=("ul",)) +    async def unload_command(self, ctx: Context, *extensions: Extension) -> None: +        r""" +        Unload currently loaded extensions given their fully qualified or unqualified names. + +        If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.send_help(ctx.command) +            return + +        blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions)) + +        if blacklisted: +            msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```" +        else: +            if "*" in extensions or "**" in extensions: +                extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST + +            msg = self.batch_manage(Action.UNLOAD, *extensions) + +        await ctx.send(msg) + +    @extensions_group.command(name="reload", aliases=("r",), root_aliases=("reload",)) +    async def reload_command(self, ctx: Context, *extensions: Extension) -> None: +        r""" +        Reload extensions given their fully qualified or unqualified names. + +        If an extension fails to be reloaded, it will be rolled-back to the prior working state. + +        If '\*' is given as the name, all currently loaded extensions will be reloaded. +        If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded. +        """  # noqa: W605 +        if not extensions: +            await ctx.send_help(ctx.command) +            return + +        if "**" in extensions: +            extensions = EXTENSIONS +        elif "*" in extensions: +            extensions = set(self.bot.extensions.keys()) | set(extensions) +            extensions.remove("*") + +        msg = self.batch_manage(Action.RELOAD, *extensions) + +        await ctx.send(msg) + +    @extensions_group.command(name="list", aliases=("all",)) +    async def list_command(self, ctx: Context) -> None: +        """ +        Get a list of all extensions, including their loaded status. + +        Grey indicates that the extension is unloaded. +        Green indicates that the extension is currently loaded. +        """ +        embed = Embed(colour=Colour.blurple()) +        embed.set_author( +            name="Extensions List", +            url=Client.github_bot_repo, +            icon_url=str(self.bot.user.avatar_url) +        ) + +        lines = [] +        categories = self.group_extension_statuses() +        for category, extensions in sorted(categories.items()): +            # Treat each category as a single line by concatenating everything. +            # This ensures the paginator will not cut off a page in the middle of a category. +            category = category.replace("_", " ").title() +            extensions = "\n".join(sorted(extensions)) +            lines.append(f"**{category}**\n{extensions}\n") + +        log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.") +        await LinePaginator.paginate(lines, ctx, embed, max_size=1200, empty=False) + +    def group_extension_statuses(self) -> t.Mapping[str, str]: +        """Return a mapping of extension names and statuses to their categories.""" +        categories = {} + +        for ext in EXTENSIONS: +            if ext in self.bot.extensions: +                status = Emojis.status_online +            else: +                status = Emojis.status_offline + +            path = ext.split(".") +            if len(path) > BASE_PATH_LEN + 1: +                category = " - ".join(path[BASE_PATH_LEN:-1]) +            else: +                category = "uncategorised" + +            categories.setdefault(category, []).append(f"{status}  {path[-1]}") + +        return categories + +    def batch_manage(self, action: Action, *extensions: str) -> str: +        """ +        Apply an action to multiple extensions and return a message with the results. + +        If only one extension is given, it is deferred to `manage()`. +        """ +        if len(extensions) == 1: +            msg, _ = self.manage(action, extensions[0]) +            return msg + +        verb = action.name.lower() +        failures = {} + +        for extension in extensions: +            _, error = self.manage(action, extension) +            if error: +                failures[extension] = error + +        emoji = ":x:" if failures else ":ok_hand:" +        msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed." + +        if failures: +            failures = "\n".join(f"{ext}\n    {err}" for ext, err in failures.items()) +            msg += f"\nFailures:```{failures}```" + +        log.debug(f"Batch {verb}ed extensions.") + +        return msg + +    def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]: +        """Apply an action to an extension and return the status message and any error message.""" +        verb = action.name.lower() +        error_msg = None + +        try: +            action.value(self.bot, ext) +        except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded): +            if action is Action.RELOAD: +                # When reloading, just load the extension if it was not loaded. +                return self.manage(Action.LOAD, ext) + +            msg = f":x: Extension `{ext}` is already {verb}ed." +            log.debug(msg[4:]) +        except Exception as e: +            if hasattr(e, "original"): +                e = e.original + +            log.exception(f"Extension '{ext}' failed to {verb}.") + +            error_msg = f"{e.__class__.__name__}: {e}" +            msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```" +        else: +            msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`." +            log.debug(msg[10:]) + +        return msg, error_msg + +    # This cannot be static (must have a __func__ attribute). +    def cog_check(self, ctx: Context) -> bool: +        """Only allow moderators and core developers to invoke the commands in this cog.""" +        return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) + +    # This cannot be static (must have a __func__ attribute). +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Handle BadArgument errors locally to prevent the help command from showing.""" +        if isinstance(error, commands.BadArgument): +            await ctx.send(str(error)) +            error.handled = True + + +def setup(bot: Bot) -> None: +    """Load the Extensions cog.""" +    bot.add_cog(Extensions(bot)) diff --git a/bot/utils/checks.py b/bot/utils/checks.py new file mode 100644 index 00000000..3031a271 --- /dev/null +++ b/bot/utils/checks.py @@ -0,0 +1,164 @@ +import datetime +import logging +from typing import Callable, Container, Iterable, Optional + +from discord.ext.commands import ( +    BucketType, +    CheckFailure, +    Cog, +    Command, +    CommandOnCooldown, +    Context, +    Cooldown, +    CooldownMapping, +) + +from bot import constants + +log = logging.getLogger(__name__) + + +class InWhitelistCheckFailure(CheckFailure): +    """Raised when the `in_whitelist` check fails.""" + +    def __init__(self, redirect_channel: Optional[int]) -> None: +        self.redirect_channel = redirect_channel + +        if redirect_channel: +            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +        else: +            redirect_message = "" + +        error_message = f"You are not allowed to use that command{redirect_message}." + +        super().__init__(error_message) + + +def in_whitelist_check( +    ctx: Context, +    channels: Container[int] = (), +    categories: Container[int] = (), +    roles: Container[int] = (), +    redirect: Optional[int] = constants.Channels.seasonalbot_commands, +    fail_silently: bool = False, +) -> bool: +    """ +    Check if a command was issued in a whitelisted context. + +    The whitelists that can be provided are: + +    - `channels`: a container with channel ids for whitelisted channels +    - `categories`: a container with category ids for whitelisted categories +    - `roles`: a container with with role ids for whitelisted roles + +    If the command was invoked in a context that was not whitelisted, the member is either +    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply +    told that they're not allowed to use this particular command (if `None` was passed). +    """ +    if redirect and redirect not in channels: +        # It does not make sense for the channel whitelist to not contain the redirection +        # channel (if applicable). That's why we add the redirection channel to the `channels` +        # container if it's not already in it. As we allow any container type to be passed, +        # we first create a tuple in order to safely add the redirection channel. +        # +        # Note: It's possible for the redirect channel to be in a whitelisted category, but +        # there's no easy way to check that and as a channel can easily be moved in and out of +        # categories, it's probably not wise to rely on its category in any case. +        channels = tuple(channels) + (redirect,) + +    if channels and ctx.channel.id in channels: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") +        return True + +    # Only check the category id if we have a category whitelist and the channel has a `category_id` +    if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") +        return True + +    # Only check the roles whitelist if we have one and ensure the author's roles attribute returns +    # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). +    if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") +        return True + +    log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + +    # Some commands are secret, and should produce no feedback at all. +    if not fail_silently: +        raise InWhitelistCheckFailure(redirect) +    return False + + +def with_role_check(ctx: Context, *role_ids: int) -> bool: +    """Returns True if the user has any one of the roles in role_ids.""" +    if not ctx.guild:  # Return False in a DM +        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " +                  "This command is restricted by the with_role decorator. Rejecting request.") +        return False + +    for role in ctx.author.roles: +        if role.id in role_ids: +            log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.") +            return True + +    log.trace(f"{ctx.author} does not have the required role to use " +              f"the '{ctx.command.name}' command, so the request is rejected.") +    return False + + +def without_role_check(ctx: Context, *role_ids: int) -> bool: +    """Returns True if the user does not have any of the roles in role_ids.""" +    if not ctx.guild:  # Return False in a DM +        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " +                  "This command is restricted by the without_role decorator. Rejecting request.") +        return False + +    author_roles = [role.id for role in ctx.author.roles] +    check = all(role not in author_roles for role in role_ids) +    log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " +              f"The result of the without_role check was {check}.") +    return check + + +def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, +                              bypass_roles: Iterable[int]) -> Callable: +    """ +    Applies a cooldown to a command, but allows members with certain roles to be ignored. + +    NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future. +    """ +    # Make it a set so lookup is hash based. +    bypass = set(bypass_roles) + +    # This handles the actual cooldown logic. +    buckets = CooldownMapping(Cooldown(rate, per, type)) + +    # Will be called after the command has been parse but before it has been invoked, ensures that +    # the cooldown won't be updated if the user screws up their input to the command. +    async def predicate(cog: Cog, ctx: Context) -> None: +        nonlocal bypass, buckets + +        if any(role.id in bypass for role in ctx.author.roles): +            return + +        # Cooldown logic, taken from discord.py internals. +        current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp() +        bucket = buckets.get_bucket(ctx.message) +        retry_after = bucket.update_rate_limit(current) +        if retry_after: +            raise CommandOnCooldown(bucket, retry_after) + +    def wrapper(command: Command) -> Command: +        # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it +        # so I just made it raise an error when the decorator is applied before the actual command object exists. +        # +        # If the `before_invoke` detail is ever a problem then I can quickly just swap over. +        if not isinstance(command, Command): +            raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. ' +                            'This means it has to be above the command decorator in the code.') + +        command._before_invoke = predicate + +        return command + +    return wrapper diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py new file mode 100644 index 00000000..50350ea8 --- /dev/null +++ b/bot/utils/extensions.py @@ -0,0 +1,34 @@ +import importlib +import inspect +import pkgutil +from typing import Iterator, NoReturn + +from bot import exts + + +def unqualify(name: str) -> str: +    """Return an unqualified name given a qualified module/package `name`.""" +    return name.rsplit(".", maxsplit=1)[-1] + + +def walk_extensions() -> Iterator[str]: +    """Yield extension names from the bot.exts subpackage.""" + +    def on_error(name: str) -> NoReturn: +        raise ImportError(name=name)  # pragma: no cover + +    for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error): +        if unqualify(module.name).startswith("_"): +            # Ignore module/package names starting with an underscore. +            continue + +        if module.ispkg: +            imported = importlib.import_module(module.name) +            if not inspect.isfunction(getattr(imported, "setup", None)): +                # If it lacks a setup function, it's not an extension. +                continue + +        yield module.name + + +EXTENSIONS = frozenset(walk_extensions()) | 
