diff options
| author | 2019-08-05 23:41:55 +0800 | |
|---|---|---|
| committer | 2019-08-06 00:23:51 +0800 | |
| commit | 999aac742abe475048b4322ba6c849b0dc1d82df (patch) | |
| tree | 35961b17f7f48b55016f71ba4a3e588a655d4d0d /bot/decorators.py | |
| parent | Add constant groups to `constants.py` (diff) | |
Split in_channel's predicate and check, add bypass_roles functionality
Separate the predicate function from `commands.check` to allow the predicate check to be added to the bot.
Diffstat (limited to 'bot/decorators.py')
| -rw-r--r-- | bot/decorators.py | 62 | 
1 files changed, 54 insertions, 8 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index dfe80e5c..f556660e 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,18 +1,25 @@  import logging  import random +import typing  from asyncio import Lock  from functools import wraps  from weakref import WeakValueDictionary  from discord import Colour, Embed  from discord.ext import commands -from discord.ext.commands import Context +from discord.ext.commands import CheckFailure, Context  from bot.constants import ERROR_REPLIES  log = logging.getLogger(__name__) +class InChannelCheckFailure(CheckFailure): +    """Check failure when the user runs a command in a non-whitelisted channel.""" + +    pass + +  def with_role(*role_ids: int):      """Check to see whether the invoking user has any of the roles specified in role_ids."""      async def predicate(ctx: Context): @@ -48,14 +55,53 @@ def without_role(*role_ids: int):      return commands.check(predicate) -def in_channel(channel_id): -    """Check that the command invocation is in the channel specified by channel_id.""" -    async def predicate(ctx: Context): -        check = ctx.channel.id == channel_id +def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]: +    """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" +    def predicate(ctx: Context) -> bool: +        if not ctx.guild: +            log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") +            return True + +        if ctx.channel.id in channels: +            log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " +                      f"The command was used in a whitelisted channel.") +            return True + +        if hasattr(ctx.command.callback, "in_channel_override"): +            log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " +                      f"The command was not used in a whitelisted channel, " +                      f"but the command was whitelisted to bypass the in_channel check.") +            return True + +        if bypass_roles: +            if any(r.id in bypass_roles for r in ctx.author.roles): +                log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " +                          f"The command was not used in a whitelisted channel, " +                          f"but the author had a role to bypass the in_channel check.") +                return True +          log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -                  f"The result of the in_channel check was {check}.") -        return check -    return commands.check(predicate) +                  f"The in_channel check failed.") + +        channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) +        raise InChannelCheckFailure( +            f"Sorry, but you may only use this command within {channels_str}." +        ) + +    return predicate + + +in_channel = commands.check(in_channel_check) + + +def override_in_channel(func: typing.Callable) -> typing.Callable: +    """ +    Set command callback attribute for detection in `in_channel_check`. + +    This decorator has to go before (below) below the `command` decorator. +    """ +    func.in_channel_override = True +    return func  def locked(): | 
