diff options
| author | 2019-10-03 02:11:58 +1000 | |
|---|---|---|
| committer | 2019-10-03 02:11:58 +1000 | |
| commit | 72b5fc1df29e0861f0306a76879ee057f607f531 (patch) | |
| tree | b69e3776ea1dce7641d523ad0657fdf9455de61f /bot/decorators.py | |
| parent | Applied suggestions from code review (diff) | |
| parent | Merge pull request #285 from Numerlor/hacktober-date-channel-fixes (diff) | |
Merge branch 'master' into trivia_quiz
Diffstat (limited to 'bot/decorators.py')
| -rw-r--r-- | bot/decorators.py | 47 | 
1 files changed, 37 insertions, 10 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index dbaad4a2..2c042b56 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -64,12 +64,16 @@ def without_role(*role_ids: int) -> bool:  def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]: -    """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" +    """ +    Checks that the message is in a whitelisted channel or optionally has a bypass role. + +    If `in_channel_override` is present, check if it contains channels +    and use them in place of the global whitelist. +    """      def predicate(ctx: Context) -> bool:          if not ctx.guild:              log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")              return True -          if ctx.channel.id in channels:              log.debug(                  f"{ctx.author} tried to call the '{ctx.command.name}' command " @@ -78,11 +82,29 @@ def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None)              return True          if hasattr(ctx.command.callback, "in_channel_override"): -            log.debug( -                f"{ctx.author} called the '{ctx.command.name}' command " -                f"and the command was whitelisted to bypass the in_channel check." -            ) -            return True +            override = ctx.command.callback.in_channel_override +            if override is None: +                log.debug( +                    f"{ctx.author} called the '{ctx.command.name}' command " +                    f"and the command was whitelisted to bypass the in_channel check." +                ) +                return True +            else: +                if ctx.channel.id in override: +                    log.debug( +                        f"{ctx.author} tried to call the '{ctx.command.name}' command " +                        f"and the command was used in an overridden whitelisted channel." +                    ) +                    return True + +                log.debug( +                    f"{ctx.author} tried to call the '{ctx.command.name}' command. " +                    f"The overridden in_channel check failed." +                ) +                channels_str = ', '.join(f"<#{c_id}>" for c_id in override) +                raise InChannelCheckFailure( +                    f"Sorry, but you may only use this command within {channels_str}." +                )          if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):              log.debug( @@ -107,14 +129,19 @@ def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None)  in_channel = commands.check(in_channel_check) -def override_in_channel(func: typing.Callable) -> typing.Callable: +def override_in_channel(channels: typing.Tuple[int] = None) -> typing.Callable:      """      Set command callback attribute for detection in `in_channel_check`. +    Override global whitelist if channels are specified. +      This decorator has to go before (below) below the `command` decorator.      """ -    func.in_channel_override = True -    return func +    def inner(func: typing.Callable) -> typing.Callable: +        func.in_channel_override = channels +        return func + +    return inner  def locked() -> typing.Union[typing.Callable, None]: | 
