aboutsummaryrefslogtreecommitdiffstats
path: root/bot/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/decorators.py')
-rw-r--r--bot/decorators.py107
1 files changed, 83 insertions, 24 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index dfe80e5c..dbaad4a2 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,24 +1,33 @@
import logging
import random
+import typing
from asyncio import Lock
from functools import wraps
from weakref import WeakValueDictionary
from discord import Colour, Embed
from discord.ext import commands
-from discord.ext.commands import Context
+from discord.ext.commands import CheckFailure, Context
from bot.constants import ERROR_REPLIES
log = logging.getLogger(__name__)
-def with_role(*role_ids: int):
+class InChannelCheckFailure(CheckFailure):
+ """Check failure when the user runs a command in a non-whitelisted channel."""
+
+ pass
+
+
+def with_role(*role_ids: int) -> bool:
"""Check to see whether the invoking user has any of the roles specified in role_ids."""
- async def predicate(ctx: Context):
+ async def predicate(ctx: Context) -> bool:
if not ctx.guild: # Return False in a DM
- log.debug(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
- "This command is restricted by the with_role decorator. Rejecting request.")
+ log.debug(
+ f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
+ "This command is restricted by the with_role decorator. Rejecting request."
+ )
return False
for role in ctx.author.roles:
@@ -26,39 +35,89 @@ def with_role(*role_ids: int):
log.debug(f"{ctx.author} has the '{role.name}' role, and passes the check.")
return True
- log.debug(f"{ctx.author} does not have the required role to use "
- f"the '{ctx.command.name}' command, so the request is rejected.")
+ log.debug(
+ f"{ctx.author} does not have the required role to use "
+ f"the '{ctx.command.name}' command, so the request is rejected."
+ )
return False
return commands.check(predicate)
-def without_role(*role_ids: int):
+def without_role(*role_ids: int) -> bool:
"""Check whether the invoking user does not have all of the roles specified in role_ids."""
- async def predicate(ctx: Context):
+ async def predicate(ctx: Context) -> bool:
if not ctx.guild: # Return False in a DM
- log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
- "This command is restricted by the without_role decorator. Rejecting request.")
+ log.debug(
+ f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
+ "This command is restricted by the without_role decorator. Rejecting request."
+ )
return False
author_roles = [role.id for role in ctx.author.roles]
check = all(role not in author_roles for role in role_ids)
- log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the without_role check was {check}.")
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The result of the without_role check was {check}."
+ )
return check
return commands.check(predicate)
-def in_channel(channel_id):
- """Check that the command invocation is in the channel specified by channel_id."""
- async def predicate(ctx: Context):
- check = ctx.channel.id == channel_id
- log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the in_channel check was {check}.")
- return check
- return commands.check(predicate)
+def in_channel_check(*channels: int, bypass_roles: typing.Container[int] = None) -> typing.Callable[[Context], bool]:
+ """Checks that the message is in a whitelisted channel or optionally has a bypass role."""
+ def predicate(ctx: Context) -> bool:
+ if not ctx.guild:
+ log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
+ return True
+
+ if ctx.channel.id in channels:
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command "
+ f"and the command was used in a whitelisted channel."
+ )
+ return True
+
+ if hasattr(ctx.command.callback, "in_channel_override"):
+ log.debug(
+ f"{ctx.author} called the '{ctx.command.name}' command "
+ f"and the command was whitelisted to bypass the in_channel check."
+ )
+ return True
+
+ if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):
+ log.debug(
+ f"{ctx.author} called the '{ctx.command.name}' command and "
+ f"had a role to bypass the in_channel check."
+ )
+ return True
+
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The in_channel check failed."
+ )
+
+ channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
+ raise InChannelCheckFailure(
+ f"Sorry, but you may only use this command within {channels_str}."
+ )
+
+ return predicate
+
+
+in_channel = commands.check(in_channel_check)
+
+
+def override_in_channel(func: typing.Callable) -> typing.Callable:
+ """
+ Set command callback attribute for detection in `in_channel_check`.
+
+ This decorator has to go before (below) below the `command` decorator.
+ """
+ func.in_channel_override = True
+ return func
-def locked():
+def locked() -> typing.Union[typing.Callable, None]:
"""
Allows the user to only run one instance of the decorated command at a time.
@@ -66,11 +125,11 @@ def locked():
This decorator has to go before (below) the `command` decorator.
"""
- def wrap(func):
+ def wrap(func: typing.Callable) -> typing.Union[typing.Callable, None]:
func.__locks = WeakValueDictionary()
@wraps(func)
- async def inner(self, ctx, *args, **kwargs):
+ async def inner(self: typing.Callable, ctx: Context, *args, **kwargs) -> typing.Union[typing.Callable, None]:
lock = func.__locks.setdefault(ctx.author.id, Lock())
if lock.locked():
embed = Embed()