| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
 | import asyncio
import functools
import logging
import random
import typing as t
from asyncio import Lock
from functools import wraps
from weakref import WeakValueDictionary
from discord import Colour, Embed
from discord.ext import commands
from discord.ext.commands import CheckFailure, Command, Context
from bot.constants import Client, ERROR_REPLIES, Month
from bot.utils import resolve_current_month
# Number of seconds in 24 hours; default sleep interval of `seasonal_task`
ONE_DAY = 24 * 60 * 60
# Module-level logger, named after this module
log = logging.getLogger(__name__)
class InChannelCheckFailure(CheckFailure):
    """Raised by the in_channel check when a command is used outside of its whitelisted channels."""
class InMonthCheckFailure(CheckFailure):
    """Raised by the in_month command check when a command is used outside of its allowed months."""
def seasonal_task(*allowed_months: Month, sleep_time: t.Union[float, int] = ONE_DAY) -> t.Callable:
    """
    Turn the decorated coroutine function into a never-ending, month-restricted periodic task.

    Every `sleep_time` seconds (24 hours by default) the current month is resolved via
    `resolve_current_month`. The decorated function is awaited when that month is one of
    `allowed_months`; otherwise the iteration is only logged and skipped.

    Waiting for the bot to be ready, if necessary, is left to the decorated function.
    """
    def decorator(task_body: t.Callable) -> t.Callable:
        @functools.wraps(task_body)
        async def periodic_task(*args, **kwargs) -> None:
            """Await `task_body` every `sleep_time` seconds, skipping months outside of `allowed_months`."""
            log.info(f"Starting seasonal task {task_body.__qualname__} ({allowed_months})")

            while True:
                month_now = resolve_current_month()

                if month_now not in allowed_months:
                    log.debug(f"Seasonal task {task_body.__qualname__} sleeps in {month_now.name}")
                else:
                    await task_body(*args, **kwargs)

                # The interval is slept regardless of whether the task body ran
                await asyncio.sleep(sleep_time)
        return periodic_task
    return decorator
def in_month_listener(*allowed_months: Month) -> t.Callable:
    """
    Make the decorated listener a no-op outside of `allowed_months`.

    The month is resolved with `resolve_current_month` on every invocation.
    """
    def decorator(listener: t.Callable) -> t.Callable:
        @functools.wraps(listener)
        async def guarded_listener(*args, **kwargs) -> None:
            """Run the wrapped listener only when the current month is allowed."""
            month_now = resolve_current_month()

            if month_now not in allowed_months:
                log.debug(f"Guarded {listener.__qualname__} from invoking in {month_now.name}")
                return None

            # Listeners are expected to return None, but whatever comes back is passed on
            return await listener(*args, **kwargs)
        return guarded_listener
    return decorator
def in_month_command(*allowed_months: Month) -> t.Callable:
    """
    Build a command check that only passes during one of `allowed_months`.

    The month is resolved with `resolve_current_month` each time the check runs.
    A failing check raises `InMonthCheckFailure` naming the months the command is usable in.
    """
    async def predicate(ctx: Context) -> bool:
        month_now = resolve_current_month()
        is_allowed = month_now in allowed_months
        month_names = ", ".join(month.name for month in allowed_months)
        verdict = "allowed" if is_allowed else "disallowed"

        log.debug(
            f"Command '{ctx.command}' is locked to months {month_names}. "
            f"Invoking it in month {month_now} is {verdict}."
        )

        if not is_allowed:
            raise InMonthCheckFailure(f"Command can only be used in {month_names}")
        return True
    return commands.check(predicate)
def in_month(*allowed_months: Month) -> t.Callable:
    """
    Universal decorator for season-locking commands and listeners alike.

    This only serves to determine whether the decorated callable is a command,
    a listener, or neither. It then delegates to either `in_month_command`,
    or `in_month_listener`, or raises TypeError, respectively.

    Please note that in order for this decorator to correctly determine whether
    the decorated callable is a cmd or listener, it **has** to first be turned
    into one. This means that this decorator should always be placed **above**
    the d.py one that registers it as either.

    This will decorate groups as well, as those subclass Command. In order to lock
    all subcommands of a group, its `invoke_without_command` param must **not** be
    manually set to True - this causes a circumvention of the group's callback
    and the seasonal check applied to it.
    """
    def decorator(callable_: t.Callable) -> t.Callable:
        # Functions decorated as commands are turned into instances of `Command`
        if isinstance(callable_, Command):
            # Log through the module logger, like the rest of this module, rather than the root logger
            log.debug(f"Command {callable_.qualified_name} will be locked to {allowed_months}")
            actual_deco = in_month_command(*allowed_months)

        # D.py will assign this attribute when `callable_` is registered as a listener
        elif hasattr(callable_, "__cog_listener__"):
            log.debug(f"Listener {callable_.__qualname__} will be locked to {allowed_months}")
            actual_deco = in_month_listener(*allowed_months)

        # Otherwise we're unsure exactly what has been decorated
        # This happens before the bot starts, so let's just raise
        else:
            raise TypeError(f"Decorated object {callable_} is neither a command nor a listener")

        return actual_deco(callable_)
    return decorator
def with_role(*role_ids: int) -> t.Callable:
    """
    Check to see whether the invoking user has any of the roles specified in role_ids.

    The check always fails when the command is invoked from a DM.
    """
    async def predicate(ctx: Context) -> bool:
        if not ctx.guild:  # Return False in a DM
            log.debug(
                f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
                "This command is restricted by the with_role decorator. Rejecting request."
            )
            return False

        # One matching role is enough; the role itself is only needed to name it in the log
        matching_role = next((role for role in ctx.author.roles if role.id in role_ids), None)
        if matching_role is not None:
            log.debug(f"{ctx.author} has the '{matching_role.name}' role, and passes the check.")
            return True

        log.debug(
            f"{ctx.author} does not have the required role to use "
            f"the '{ctx.command.name}' command, so the request is rejected."
        )
        return False
    return commands.check(predicate)
def without_role(*role_ids: int) -> t.Callable:
    """
    Check that the invoking user has none of the roles specified in role_ids.

    The check always fails when the command is invoked from a DM.
    """
    async def predicate(ctx: Context) -> bool:
        if not ctx.guild:  # Return False in a DM
            log.debug(
                f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
                "This command is restricted by the without_role decorator. Rejecting request."
            )
            return False

        held_role_ids = [role.id for role in ctx.author.roles]
        # Passing means that not a single one of the listed roles is held by the author
        check = not any(role_id in held_role_ids for role_id in role_ids)

        log.debug(
            f"{ctx.author} tried to call the '{ctx.command.name}' command. "
            f"The result of the without_role check was {check}."
        )
        return check
    return commands.check(predicate)
def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]:
    """
    Build a predicate that restricts commands to whitelisted channels.

    The predicate passes when the command is invoked in a DM, in one of `channels`,
    or by an author holding a role listed in `bypass_roles`.

    A command callback carrying an `in_channel_override` attribute is treated specially:
    a value of None exempts the command from the check altogether, any other value is
    used as the whitelist in place of `channels`.

    In every other case `InChannelCheckFailure` is raised, with a message listing the
    channels the command may be used in.
    """
    def predicate(ctx: Context) -> bool:
        if not ctx.guild:
            log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
            return True

        if ctx.channel.id in channels:
            log.debug(
                f"{ctx.author} tried to call the '{ctx.command.name}' command "
                f"and the command was used in a whitelisted channel."
            )
            return True

        if bypass_roles and any(role.id in bypass_roles for role in ctx.author.roles):
            log.debug(
                f"{ctx.author} called the '{ctx.command.name}' command and "
                f"had a role to bypass the in_channel check."
            )
            return True

        # Unless the command overrides it, a failure is reported against the global whitelist
        permitted_channels = channels
        failure_note = "The in_channel check failed."

        callback = ctx.command.callback
        if hasattr(callback, "in_channel_override"):
            override = callback.in_channel_override

            if override is None:
                log.debug(
                    f"{ctx.author} called the '{ctx.command.name}' command "
                    f"and the command was whitelisted to bypass the in_channel check."
                )
                return True

            if ctx.channel.id in override:
                log.debug(
                    f"{ctx.author} tried to call the '{ctx.command.name}' command "
                    f"and the command was used in an overridden whitelisted channel."
                )
                return True

            permitted_channels = override
            failure_note = "The overridden in_channel check failed."

        log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. {failure_note}")

        channels_str = ', '.join(f"<#{c_id}>" for c_id in permitted_channels)
        raise InChannelCheckFailure(
            f"Sorry, but you may only use this command within {channels_str}."
        )
    return predicate
# NOTE(review): every other `commands.check` call in this module is given a predicate taking
# a Context, whereas this one is given the predicate *factory*. As written, the Context would
# be received as a channel id and the factory's return value (the inner predicate function)
# would be taken as the check result - confirm whether this alias is used and works as intended.
in_channel = commands.check(in_channel_check)
def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable:
    """
    Tag a command callback with the `in_channel_override` attribute read by `in_channel_check`.

    When `channels` are given, they replace the global whitelist for the decorated command.
    When left as None, None is stored, which `in_channel_check` treats as an exemption.

    This decorator has to go before (below) the `command` decorator.
    """
    def tag_callback(callback: t.Callable) -> t.Callable:
        setattr(callback, "in_channel_override", channels)
        return callback
    return tag_callback
def locked() -> t.Callable:
    """
    Allow each user to run only one instance of the decorated command at a time.

    While an invocation by a given author is still running, further invocations by the
    same author are answered with an error embed and are otherwise ignored.

    This decorator has to go before (below) the `command` decorator.
    """
    def wrap(func: t.Callable) -> t.Callable:
        # Per-author locks are stored weakly, so an entry only lives while something references its lock
        func.__locks = WeakValueDictionary()

        @wraps(func)
        async def inner(self: t.Any, ctx: Context, *args, **kwargs) -> t.Any:
            """Invoke `func` unless the author already has an invocation of it in progress."""
            # This local reference is what keeps the weakly-stored lock alive during the invocation
            lock = func.__locks.setdefault(ctx.author.id, Lock())

            if lock.locked():
                embed = Embed()
                embed.colour = Colour.red()

                log.debug("User tried to invoke a locked command.")
                embed.description = (
                    "You're already using this command. Please wait until "
                    "it is done before you use it again."
                )
                embed.title = random.choice(ERROR_REPLIES)
                await ctx.send(embed=embed)
                return

            # Reuse the lock fetched above rather than looking it up (and allocating a spare Lock) again
            async with lock:
                return await func(self, ctx, *args, **kwargs)
        return inner
    return wrap
def mock_in_debug(return_value: t.Any) -> t.Callable:
    """
    Replace the decorated coroutine function with a stub while `Client.debug` is truthy.

    In debug mode the call is not forwarded: the function name and the received args and
    kwargs are logged at DEBUG level and `return_value` is returned instead. Outside of
    debug mode the decorated function is awaited as usual and its result returned.

    Intended for expensive operations, e.g. rate-limited media asset uploads, which
    still need to be tested extensively.
    """
    def decorator(func: t.Callable) -> t.Callable:
        @functools.wraps(func)
        async def wrapped(*args, **kwargs) -> t.Any:
            """Forward the call outside of debug mode; log and short-circuit otherwise."""
            if not Client.debug:
                return await func(*args, **kwargs)

            log.debug(f"Function {func.__name__} called with args: {args}, kwargs: {kwargs}")
            return return_value
        return wrapped
    return decorator
 |