diff options
Diffstat (limited to 'bot/decorators.py')
-rw-r--r-- | bot/decorators.py | 62 |
1 files changed, 54 insertions, 8 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index dfe80e5c..f556660e 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,18 +1,25 @@ import logging import random +import typing from asyncio import Lock from functools import wraps from weakref import WeakValueDictionary from discord import Colour, Embed from discord.ext import commands -from discord.ext.commands import Context +from discord.ext.commands import CheckFailure, Context from bot.constants import ERROR_REPLIES log = logging.getLogger(__name__) +class InChannelCheckFailure(CheckFailure): + """Check failure when the user runs a command in a non-whitelisted channel.""" + + pass + + def with_role(*role_ids: int): """Check to see whether the invoking user has any of the roles specified in role_ids.""" async def predicate(ctx: Context): @@ -48,14 +55,53 @@ def without_role(*role_ids: int): return commands.check(predicate) -def in_channel(channel_id): - """Check that the command invocation is in the channel specified by channel_id.""" - async def predicate(ctx: Context): - check = ctx.channel.id == channel_id +def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]: + """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" + def predicate(ctx: Context) -> bool: + if not ctx.guild: + log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") + return True + + if ctx.channel.id in channels: + log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " + f"The command was used in a whitelisted channel.") + return True + + if hasattr(ctx.command.callback, "in_channel_override"): + log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " + f"The command was not used in a whitelisted channel, " + f"but the command was whitelisted to bypass the in_channel check.") + return True + + if bypass_roles: + if any(r.id in bypass_roles for r in ctx.author.roles): + log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " + f"The command was not used in a whitelisted channel, " + f"but the author had a role to bypass the in_channel check.") + return True + log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The result of the in_channel check was {check}.") - return check - return commands.check(predicate) + f"The in_channel check failed.") + + channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) + raise InChannelCheckFailure( + f"Sorry, but you may only use this command within {channels_str}." + ) + + return predicate + + +in_channel = commands.check(in_channel_check) + + +def override_in_channel(func: typing.Callable) -> typing.Callable: + """ + Set command callback attribute for detection in `in_channel_check`. + + This decorator has to go before (below) below the `command` decorator. + """ + func.in_channel_override = True + return func def locked(): |