aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_core/utils/checks.py
blob: 809e98de337a1d3b9a2f204887c7d5256c772a9a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import datetime
from collections.abc import Callable, Container, Iterable

from discord.ext.commands import (
    BucketType,
    CheckFailure,
    Cog,
    Command,
    CommandOnCooldown,
    Context,
    Cooldown,
    CooldownMapping,
)

from pydis_core.utils.logging import get_logger

log = get_logger(__name__)


class ContextCheckFailure(CheckFailure):
    """Raised when a context-specific check fails."""

    def __init__(self, redirect_channel: int | None) -> None:
        self.redirect_channel = redirect_channel

        if redirect_channel:
            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
        else:
            redirect_message = ""

        error_message = f"You are not allowed to use that command{redirect_message}."

        super().__init__(error_message)


class InWhitelistCheckFailure(ContextCheckFailure):
    """Raised when the `in_whitelist` check fails."""


def in_whitelist_check(
    ctx: Context,
    redirect: int,
    channels: Container[int] = (),
    categories: Container[int] = (),
    roles: Container[int] = (),
    fail_silently: bool = False,
) -> bool:
    """
    Check if a command was issued in a whitelisted context.

    The whitelists that can be provided are:

    - `channels`: a container with channel ids for whitelisted channels
    - `categories`: a container with category ids for whitelisted categories
    - `roles`: a container with with role ids for whitelisted roles

    If the command was invoked in a context that was not whitelisted, the member is either
    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
    told that they're not allowed to use this particular command (if `None` was passed).
    """
    if redirect not in channels:
        # It does not make sense for the channel whitelist to not contain the redirection
        # channel (if applicable). That's why we add the redirection channel to the `channels`
        # container if it's not already in it. As we allow any container type to be passed,
        # we first create a tuple in order to safely add the redirection channel.
        #
        # Note: It's possible for the redirect channel to be in a whitelisted category, but
        # there's no easy way to check that and as a channel can easily be moved in and out of
        # categories, it's probably not wise to rely on its category in any case.
        channels = tuple(channels) + (redirect,)

    if channels and ctx.channel.id in channels:
        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
        return True

    # Only check the category id if we have a category whitelist and the channel has a `category_id`
    if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
        return True

    # Only check the roles whitelist if we have one and ensure the author's roles attribute returns
    # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
    if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
        return True

    log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")

    # Some commands are secret, and should produce no feedback at all.
    if not fail_silently:
        raise InWhitelistCheckFailure(redirect)
    return False


def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
                              bypass_roles: Iterable[int]) -> Callable:
    """
    Applies a cooldown to a command, but allows members with certain roles to be ignored.

    NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future.
    """
    # Make it a set so lookup is hash based.
    bypass = set(bypass_roles)

    # This handles the actual cooldown logic.
    buckets = CooldownMapping(Cooldown(rate, per, type))

    # Will be called after the command has been parse but before it has been invoked, ensures that
    # the cooldown won't be updated if the user screws up their input to the command.
    async def predicate(cog: Cog, ctx: Context) -> None:
        nonlocal bypass, buckets

        if any(role.id in bypass for role in ctx.author.roles):
            return

        # Cooldown logic, taken from discord.py internals.
        current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp()
        bucket = buckets.get_bucket(ctx.message)
        retry_after = bucket.update_rate_limit(current)
        if retry_after:
            raise CommandOnCooldown(bucket, retry_after)

    def wrapper(command: Command) -> Command:
        # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it
        # so I just made it raise an error when the decorator is applied before the actual command object exists.
        #
        # If the `before_invoke` detail is ever a problem then I can quickly just swap over.
        if not isinstance(command, Command):
            raise TypeError(
                "Decorator `cooldown_with_role_bypass` must be applied after the command decorator. "
                "This means it has to be above the command decorator in the code."
            )

        command._before_invoke = predicate

        return command

    return wrapper


async def has_any_role_check(ctx: Context, *roles: str | int) -> bool:
    """
    Returns True if the context's author has any of the specified roles.

    `roles` are the names or IDs of the roles for which to check.
    False is always returns if the context is outside a guild.
    """
    if not ctx.guild:  # Return False in a DM
        log.trace(
            f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
            "This command is restricted by the with_role decorator. Rejecting request."
        )
        return False

    for role in ctx.author.roles:
        if role.id in roles:
            log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")
            return True

    log.trace(
        f"{ctx.author} does not have the required role to use "
        f"the '{ctx.command.name}' command, so the request is rejected."
    )
    return False


async def has_no_roles_check(ctx: Context, *roles: str | int) -> bool:
    """
    Returns True if the context's author doesn't have any of the specified roles.

    `roles` are the names or IDs of the roles for which to check.
    False is always returns if the context is outside a guild.
    """
    if not ctx.guild:  # Return False in a DM
        log.trace(
            f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
            "This command is restricted by the without_role decorator. Rejecting request."
        )
        return False

    author_roles = [role.id for role in ctx.author.roles]
    check = all(role not in author_roles for role in roles)
    log.trace(
        f"{ctx.author} tried to call the '{ctx.command.name}' command. "
        f"The result of the without_role check was {check}."
    )
    return check