aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/decorators.py62
1 files changed, 54 insertions, 8 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index dfe80e5c..f556660e 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,18 +1,25 @@
import logging
import random
+import typing
from asyncio import Lock
from functools import wraps
from weakref import WeakValueDictionary
from discord import Colour, Embed
from discord.ext import commands
-from discord.ext.commands import Context
+from discord.ext.commands import CheckFailure, Context
from bot.constants import ERROR_REPLIES
log = logging.getLogger(__name__)
+class InChannelCheckFailure(CheckFailure):
+ """Check failure when the user runs a command in a non-whitelisted channel."""
+
+ pass
+
+
def with_role(*role_ids: int):
"""Check to see whether the invoking user has any of the roles specified in role_ids."""
async def predicate(ctx: Context):
@@ -48,14 +55,53 @@ def without_role(*role_ids: int):
return commands.check(predicate)
-def in_channel(channel_id):
- """Check that the command invocation is in the channel specified by channel_id."""
- async def predicate(ctx: Context):
- check = ctx.channel.id == channel_id
+def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]:
+ """Checks that the message is in a whitelisted channel or optionally has a bypass role."""
+ def predicate(ctx: Context) -> bool:
+ if not ctx.guild:
+ log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
+ return True
+
+ if ctx.channel.id in channels:
+ log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The command was used in a whitelisted channel.")
+ return True
+
+ if hasattr(ctx.command.callback, "in_channel_override"):
+ log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The command was not used in a whitelisted channel, "
+ f"but the command was whitelisted to bypass the in_channel check.")
+ return True
+
+ if bypass_roles:
+ if any(r.id in bypass_roles for r in ctx.author.roles):
+ log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The command was not used in a whitelisted channel, "
+ f"but the author had a role to bypass the in_channel check.")
+ return True
+
log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the in_channel check was {check}.")
- return check
- return commands.check(predicate)
+ f"The in_channel check failed.")
+
+ channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
+ raise InChannelCheckFailure(
+ f"Sorry, but you may only use this command within {channels_str}."
+ )
+
+ return predicate
+
+
+in_channel = commands.check(in_channel_check)
+
+
+def override_in_channel(func: typing.Callable) -> typing.Callable:
+ """
+ Set command callback attribute for detection in `in_channel_check`.
+
+ This decorator has to go before (below) below the `command` decorator.
+ """
+ func.in_channel_override = True
+ return func
def locked():