diff options
| -rw-r--r-- | pydis_core/utils/checks.py | 187 | 
1 files changed, 187 insertions, 0 deletions
| diff --git a/pydis_core/utils/checks.py b/pydis_core/utils/checks.py new file mode 100644 index 00000000..809e98de --- /dev/null +++ b/pydis_core/utils/checks.py @@ -0,0 +1,187 @@ +import datetime +from collections.abc import Callable, Container, Iterable + +from discord.ext.commands import ( +    BucketType, +    CheckFailure, +    Cog, +    Command, +    CommandOnCooldown, +    Context, +    Cooldown, +    CooldownMapping, +) + +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) + + +class ContextCheckFailure(CheckFailure): +    """Raised when a context-specific check fails.""" + +    def __init__(self, redirect_channel: int | None) -> None: +        self.redirect_channel = redirect_channel + +        if redirect_channel: +            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +        else: +            redirect_message = "" + +        error_message = f"You are not allowed to use that command{redirect_message}." + +        super().__init__(error_message) + + +class InWhitelistCheckFailure(ContextCheckFailure): +    """Raised when the `in_whitelist` check fails.""" + + +def in_whitelist_check( +    ctx: Context, +    redirect: int, +    channels: Container[int] = (), +    categories: Container[int] = (), +    roles: Container[int] = (), +    fail_silently: bool = False, +) -> bool: +    """ +    Check if a command was issued in a whitelisted context. + +    The whitelists that can be provided are: + +    - `channels`: a container with channel ids for whitelisted channels +    - `categories`: a container with category ids for whitelisted categories +    - `roles`: a container with with role ids for whitelisted roles + +    If the command was invoked in a context that was not whitelisted, the member is either +    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply +    told that they're not allowed to use this particular command (if `None` was passed). +    """ +    if redirect not in channels: +        # It does not make sense for the channel whitelist to not contain the redirection +        # channel (if applicable). That's why we add the redirection channel to the `channels` +        # container if it's not already in it. As we allow any container type to be passed, +        # we first create a tuple in order to safely add the redirection channel. +        # +        # Note: It's possible for the redirect channel to be in a whitelisted category, but +        # there's no easy way to check that and as a channel can easily be moved in and out of +        # categories, it's probably not wise to rely on its category in any case. +        channels = tuple(channels) + (redirect,) + +    if channels and ctx.channel.id in channels: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") +        return True + +    # Only check the category id if we have a category whitelist and the channel has a `category_id` +    if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") +        return True + +    # Only check the roles whitelist if we have one and ensure the author's roles attribute returns +    # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). +    if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") +        return True + +    log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + +    # Some commands are secret, and should produce no feedback at all. +    if not fail_silently: +        raise InWhitelistCheckFailure(redirect) +    return False + + +def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, +                              bypass_roles: Iterable[int]) -> Callable: +    """ +    Applies a cooldown to a command, but allows members with certain roles to be ignored. + +    NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future. +    """ +    # Make it a set so lookup is hash based. +    bypass = set(bypass_roles) + +    # This handles the actual cooldown logic. +    buckets = CooldownMapping(Cooldown(rate, per, type)) + +    # Will be called after the command has been parse but before it has been invoked, ensures that +    # the cooldown won't be updated if the user screws up their input to the command. +    async def predicate(cog: Cog, ctx: Context) -> None: +        nonlocal bypass, buckets + +        if any(role.id in bypass for role in ctx.author.roles): +            return + +        # Cooldown logic, taken from discord.py internals. +        current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp() +        bucket = buckets.get_bucket(ctx.message) +        retry_after = bucket.update_rate_limit(current) +        if retry_after: +            raise CommandOnCooldown(bucket, retry_after) + +    def wrapper(command: Command) -> Command: +        # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it +        # so I just made it raise an error when the decorator is applied before the actual command object exists. +        # +        # If the `before_invoke` detail is ever a problem then I can quickly just swap over. +        if not isinstance(command, Command): +            raise TypeError( +                "Decorator `cooldown_with_role_bypass` must be applied after the command decorator. " +                "This means it has to be above the command decorator in the code." +            ) + +        command._before_invoke = predicate + +        return command + +    return wrapper + + +async def has_any_role_check(ctx: Context, *roles: str | int) -> bool: +    """ +    Returns True if the context's author has any of the specified roles. + +    `roles` are the names or IDs of the roles for which to check. +    False is always returns if the context is outside a guild. +    """ +    if not ctx.guild:  # Return False in a DM +        log.trace( +            f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " +            "This command is restricted by the with_role decorator. Rejecting request." +        ) +        return False + +    for role in ctx.author.roles: +        if role.id in roles: +            log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.") +            return True + +    log.trace( +        f"{ctx.author} does not have the required role to use " +        f"the '{ctx.command.name}' command, so the request is rejected." +    ) +    return False + + +async def has_no_roles_check(ctx: Context, *roles: str | int) -> bool: +    """ +    Returns True if the context's author doesn't have any of the specified roles. + +    `roles` are the names or IDs of the roles for which to check. +    False is always returns if the context is outside a guild. +    """ +    if not ctx.guild:  # Return False in a DM +        log.trace( +            f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " +            "This command is restricted by the without_role decorator. Rejecting request." +        ) +        return False + +    author_roles = [role.id for role in ctx.author.roles] +    check = all(role not in author_roles for role in roles) +    log.trace( +        f"{ctx.author} tried to call the '{ctx.command.name}' command. " +        f"The result of the without_role check was {check}." +    ) +    return check | 
