diff options
Diffstat (limited to 'bot/utils')
| -rw-r--r-- | bot/utils/checks.py | 164 | ||||
| -rw-r--r-- | bot/utils/converters.py | 16 | ||||
| -rw-r--r-- | bot/utils/decorators.py | 24 | ||||
| -rw-r--r-- | bot/utils/exceptions.py | 6 | ||||
| -rw-r--r-- | bot/utils/extensions.py | 34 | ||||
| -rw-r--r-- | bot/utils/pagination.py | 4 | ||||
| -rw-r--r-- | bot/utils/persist.py | 66 | ||||
| -rw-r--r-- | bot/utils/randomization.py | 23 | 
8 files changed, 241 insertions, 96 deletions
| diff --git a/bot/utils/checks.py b/bot/utils/checks.py new file mode 100644 index 00000000..9dd4dde0 --- /dev/null +++ b/bot/utils/checks.py @@ -0,0 +1,164 @@ +import datetime +import logging +from typing import Callable, Container, Iterable, Optional + +from discord.ext.commands import ( +    BucketType, +    CheckFailure, +    Cog, +    Command, +    CommandOnCooldown, +    Context, +    Cooldown, +    CooldownMapping, +) + +from bot import constants + +log = logging.getLogger(__name__) + + +class InWhitelistCheckFailure(CheckFailure): +    """Raised when the `in_whitelist` check fails.""" + +    def __init__(self, redirect_channel: Optional[int]) -> None: +        self.redirect_channel = redirect_channel + +        if redirect_channel: +            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +        else: +            redirect_message = "" + +        error_message = f"You are not allowed to use that command{redirect_message}." + +        super().__init__(error_message) + + +def in_whitelist_check( +    ctx: Context, +    channels: Container[int] = (), +    categories: Container[int] = (), +    roles: Container[int] = (), +    redirect: Optional[int] = constants.Channels.community_bot_commands, +    fail_silently: bool = False, +) -> bool: +    """ +    Check if a command was issued in a whitelisted context. + +    The whitelists that can be provided are: + +    - `channels`: a container with channel ids for whitelisted channels +    - `categories`: a container with category ids for whitelisted categories +    - `roles`: a container with with role ids for whitelisted roles + +    If the command was invoked in a context that was not whitelisted, the member is either +    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply +    told that they're not allowed to use this particular command (if `None` was passed). +    """ +    if redirect and redirect not in channels: +        # It does not make sense for the channel whitelist to not contain the redirection +        # channel (if applicable). That's why we add the redirection channel to the `channels` +        # container if it's not already in it. As we allow any container type to be passed, +        # we first create a tuple in order to safely add the redirection channel. +        # +        # Note: It's possible for the redirect channel to be in a whitelisted category, but +        # there's no easy way to check that and as a channel can easily be moved in and out of +        # categories, it's probably not wise to rely on its category in any case. +        channels = tuple(channels) + (redirect,) + +    if channels and ctx.channel.id in channels: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") +        return True + +    # Only check the category id if we have a category whitelist and the channel has a `category_id` +    if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") +        return True + +    # Only check the roles whitelist if we have one and ensure the author's roles attribute returns +    # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). +    if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") +        return True + +    log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + +    # Some commands are secret, and should produce no feedback at all. +    if not fail_silently: +        raise InWhitelistCheckFailure(redirect) +    return False + + +def with_role_check(ctx: Context, *role_ids: int) -> bool: +    """Returns True if the user has any one of the roles in role_ids.""" +    if not ctx.guild:  # Return False in a DM +        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " +                  "This command is restricted by the with_role decorator. Rejecting request.") +        return False + +    for role in ctx.author.roles: +        if role.id in role_ids: +            log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.") +            return True + +    log.trace(f"{ctx.author} does not have the required role to use " +              f"the '{ctx.command.name}' command, so the request is rejected.") +    return False + + +def without_role_check(ctx: Context, *role_ids: int) -> bool: +    """Returns True if the user does not have any of the roles in role_ids.""" +    if not ctx.guild:  # Return False in a DM +        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " +                  "This command is restricted by the without_role decorator. Rejecting request.") +        return False + +    author_roles = [role.id for role in ctx.author.roles] +    check = all(role not in author_roles for role in role_ids) +    log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " +              f"The result of the without_role check was {check}.") +    return check + + +def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, +                              bypass_roles: Iterable[int]) -> Callable: +    """ +    Applies a cooldown to a command, but allows members with certain roles to be ignored. + +    NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future. +    """ +    # Make it a set so lookup is hash based. +    bypass = set(bypass_roles) + +    # This handles the actual cooldown logic. +    buckets = CooldownMapping(Cooldown(rate, per, type)) + +    # Will be called after the command has been parse but before it has been invoked, ensures that +    # the cooldown won't be updated if the user screws up their input to the command. +    async def predicate(cog: Cog, ctx: Context) -> None: +        nonlocal bypass, buckets + +        if any(role.id in bypass for role in ctx.author.roles): +            return + +        # Cooldown logic, taken from discord.py internals. +        current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp() +        bucket = buckets.get_bucket(ctx.message) +        retry_after = bucket.update_rate_limit(current) +        if retry_after: +            raise CommandOnCooldown(bucket, retry_after) + +    def wrapper(command: Command) -> Command: +        # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it +        # so I just made it raise an error when the decorator is applied before the actual command object exists. +        # +        # If the `before_invoke` detail is ever a problem then I can quickly just swap over. +        if not isinstance(command, Command): +            raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. ' +                            'This means it has to be above the command decorator in the code.') + +        command._before_invoke = predicate + +        return command + +    return wrapper diff --git a/bot/utils/converters.py b/bot/utils/converters.py new file mode 100644 index 00000000..228714c9 --- /dev/null +++ b/bot/utils/converters.py @@ -0,0 +1,16 @@ +import discord +from discord.ext.commands.converter import MessageConverter + + +class WrappedMessageConverter(MessageConverter): +    """A converter that handles embed-suppressed links like <http://example.com>.""" + +    async def convert(self, ctx: discord.ext.commands.Context, argument: str) -> discord.Message: +        """Wrap the commands.MessageConverter to handle <> delimited message links.""" +        # It's possible to wrap a message in [<>] as well, and it's supported because its easy +        if argument.startswith("[") and argument.endswith("]"): +            argument = argument[1:-1] +        if argument.startswith("<") and argument.endswith(">"): +            argument = argument[1:-1] + +        return await super().convert(ctx, argument) diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 519e61a9..9cdaad3f 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -11,7 +11,7 @@ from discord import Colour, Embed  from discord.ext import commands  from discord.ext.commands import CheckFailure, Command, Context -from bot.constants import Client, ERROR_REPLIES, Month +from bot.constants import ERROR_REPLIES, Month  from bot.utils import human_months, resolve_current_month  ONE_DAY = 24 * 60 * 60 @@ -285,7 +285,7 @@ def locked() -> t.Union[t.Callable, None]:                  embed = Embed()                  embed.colour = Colour.red() -                log.debug(f"User tried to invoke a locked command.") +                log.debug("User tried to invoke a locked command.")                  embed.description = (                      "You're already using this command. Please wait until "                      "it is done before you use it again." @@ -298,23 +298,3 @@ def locked() -> t.Union[t.Callable, None]:                  return await func(self, ctx, *args, **kwargs)          return inner      return wrap - - -def mock_in_debug(return_value: t.Any) -> t.Callable: -    """ -    Short-circuit function execution if in debug mode and return `return_value`. - -    The original function name, and the incoming args and kwargs are DEBUG level logged -    upon each call. This is useful for expensive operations, i.e. media asset uploads -    that are prone to rate-limits but need to be tested extensively. -    """ -    def decorator(func: t.Callable) -> t.Callable: -        @functools.wraps(func) -        async def wrapped(*args, **kwargs) -> t.Any: -            """Short-circuit and log if in debug mode.""" -            if Client.debug: -                log.debug(f"Function {func.__name__} called with args: {args}, kwargs: {kwargs}") -                return return_value -            return await func(*args, **kwargs) -        return wrapped -    return decorator diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index dc62debe..2b1c1b31 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -1,9 +1,3 @@ -class BrandingError(Exception): -    """Exception raised by the BrandingManager cog.""" - -    pass - -  class UserNotPlayingError(Exception):      """Will raised when user try to use game commands when not playing.""" diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py new file mode 100644 index 00000000..50350ea8 --- /dev/null +++ b/bot/utils/extensions.py @@ -0,0 +1,34 @@ +import importlib +import inspect +import pkgutil +from typing import Iterator, NoReturn + +from bot import exts + + +def unqualify(name: str) -> str: +    """Return an unqualified name given a qualified module/package `name`.""" +    return name.rsplit(".", maxsplit=1)[-1] + + +def walk_extensions() -> Iterator[str]: +    """Yield extension names from the bot.exts subpackage.""" + +    def on_error(name: str) -> NoReturn: +        raise ImportError(name=name)  # pragma: no cover + +    for module in pkgutil.walk_packages(exts.__path__, f"{exts.__name__}.", onerror=on_error): +        if unqualify(module.name).startswith("_"): +            # Ignore module/package names starting with an underscore. +            continue + +        if module.ispkg: +            imported = importlib.import_module(module.name) +            if not inspect.isfunction(getattr(imported, "setup", None)): +                # If it lacks a setup function, it's not an extension. +                continue + +        yield module.name + + +EXTENSIONS = frozenset(walk_extensions()) diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py index 9a7a0382..a4d0cc56 100644 --- a/bot/utils/pagination.py +++ b/bot/utils/pagination.py @@ -128,7 +128,7 @@ class LinePaginator(Paginator):          if not lines:              if exception_on_empty_embed: -                log.exception(f"Pagination asked for empty lines iterable") +                log.exception("Pagination asked for empty lines iterable")                  raise EmptyPaginatorEmbed("No lines to paginate")              log.debug("No lines to add to paginator, adding '(nothing to display)' message") @@ -335,7 +335,7 @@ class ImagePaginator(Paginator):          if not pages:              if exception_on_empty_embed: -                log.exception(f"Pagination asked for empty image list") +                log.exception("Pagination asked for empty image list")                  raise EmptyPaginatorEmbed("No images to paginate")              log.debug("No images to add to paginator, adding '(no images to display)' message") diff --git a/bot/utils/persist.py b/bot/utils/persist.py deleted file mode 100644 index d78e5420..00000000 --- a/bot/utils/persist.py +++ /dev/null @@ -1,66 +0,0 @@ -import sqlite3 -from pathlib import Path -from shutil import copyfile - -from bot.exts import get_package_names - -DIRECTORY = Path("data")  # directory that has a persistent volume mapped to it - - -def make_persistent(file_path: Path) -> Path: -    """ -    Copy datafile at the provided file_path to the persistent data directory. - -    A persistent data file is needed by some features in order to not lose data -    after bot rebuilds. - -    This function will ensure that a clean data file with default schema, -    structure or data is copied over to the persistent volume before returning -    the path to this new persistent version of the file. - -    If the persistent file already exists, it won't be overwritten with the -    clean default file, just returning the Path instead to the existing file. - -    Note: Avoid using the same file name as other features in the same seasons -    as otherwise only one datafile can be persistent and will be returned for -    both cases. - -    Example Usage: -    >>> import json -    >>> template_datafile = Path("bot", "resources", "evergreen", "myfile.json") -    >>> path_to_persistent_file = make_persistent(template_datafile) -    >>> print(path_to_persistent_file) -    data/evergreen/myfile.json -    >>> with path_to_persistent_file.open("w+") as f: -    >>>     data = json.load(f) -    """ -    # ensure the persistent data directory exists -    DIRECTORY.mkdir(exist_ok=True) - -    if not file_path.is_file(): -        raise OSError(f"File not found at {file_path}.") - -    # detect season in datafile path for assigning to subdirectory -    season = next((s for s in get_package_names() if s in file_path.parts), None) - -    if season: -        # make sure subdirectory exists first -        subdirectory = Path(DIRECTORY, season) -        subdirectory.mkdir(exist_ok=True) - -        persistent_path = Path(subdirectory, file_path.name) - -    else: -        persistent_path = Path(DIRECTORY, file_path.name) - -    # copy base/template datafile to persistent directory -    if not persistent_path.exists(): -        copyfile(file_path, persistent_path) - -    return persistent_path - - -def sqlite(db_path: Path) -> sqlite3.Connection: -    """Copy sqlite file to the persistent data directory and return an open connection.""" -    persistent_path = make_persistent(db_path) -    return sqlite3.connect(persistent_path) diff --git a/bot/utils/randomization.py b/bot/utils/randomization.py new file mode 100644 index 00000000..8f47679a --- /dev/null +++ b/bot/utils/randomization.py @@ -0,0 +1,23 @@ +import itertools +import random +import typing as t + + +class RandomCycle: +    """ +    Cycles through elements from a randomly shuffled iterable, repeating indefinitely. + +    The iterable is reshuffled after each full cycle. +    """ + +    def __init__(self, iterable: t.Iterable) -> None: +        self.iterable = list(iterable) +        self.index = itertools.cycle(range(len(iterable))) + +    def __next__(self) -> t.Any: +        idx = next(self.index) + +        if idx == 0: +            random.shuffle(self.iterable) + +        return self.iterable[idx] | 
