aboutsummaryrefslogtreecommitdiffstats
path: root/bot/decorators.py
diff options
context:
space:
mode:
authorGravatar Sebastiaan Zeeff <[email protected]>2019-10-02 17:16:44 +0200
committerGravatar GitHub <[email protected]>2019-10-02 17:16:44 +0200
commit1974164355a0dabc884b67a93e3bf12e0ed76b11 (patch)
treedeaed1b2caf7f650f05e12613cdff5b8a12629c9 /bot/decorators.py
parentPoint setup guide to site wiki (diff)
parentMerge branch 'master' into hacktober-date-channel-fixes (diff)
Merge pull request #285 from Numerlor/hacktober-date-channel-fixes
Hacktober date range and channel whitelist fixes
Diffstat (limited to 'bot/decorators.py')
-rw-r--r--bot/decorators.py47
1 files changed, 37 insertions, 10 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index dbaad4a2..2c042b56 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -64,12 +64,16 @@ def without_role(*role_ids: int) -> bool:
def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]:
- """Checks that the message is in a whitelisted channel or optionally has a bypass role."""
+ """
+ Checks that the message is in a whitelisted channel or optionally has a bypass role.
+
+ If `in_channel_override` is present, check if it contains channels
+ and use them in place of the global whitelist.
+ """
def predicate(ctx: Context) -> bool:
if not ctx.guild:
log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
return True
-
if ctx.channel.id in channels:
log.debug(
f"{ctx.author} tried to call the '{ctx.command.name}' command "
@@ -78,11 +82,29 @@ def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None)
return True
if hasattr(ctx.command.callback, "in_channel_override"):
- log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command "
- f"and the command was whitelisted to bypass the in_channel check."
- )
- return True
+ override = ctx.command.callback.in_channel_override
+ if override is None:
+ log.debug(
+ f"{ctx.author} called the '{ctx.command.name}' command "
+ f"and the command was whitelisted to bypass the in_channel check."
+ )
+ return True
+ else:
+ if ctx.channel.id in override:
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command "
+ f"and the command was used in an overridden whitelisted channel."
+ )
+ return True
+
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The overridden in_channel check failed."
+ )
+ channels_str = ', '.join(f"<#{c_id}>" for c_id in override)
+ raise InChannelCheckFailure(
+ f"Sorry, but you may only use this command within {channels_str}."
+ )
if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):
log.debug(
@@ -107,14 +129,19 @@ def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None)
in_channel = commands.check(in_channel_check)
-def override_in_channel(func: typing.Callable) -> typing.Callable:
+def override_in_channel(channels: typing.Tuple[int] = None) -> typing.Callable:
"""
Set command callback attribute for detection in `in_channel_check`.
+ Override global whitelist if channels are specified.
+
This decorator has to go before (below) below the `command` decorator.
"""
- func.in_channel_override = True
- return func
+ def inner(func: typing.Callable) -> typing.Callable:
+ func.in_channel_override = channels
+ return func
+
+ return inner
def locked() -> typing.Union[typing.Callable, None]: