aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils/decorators.py')
-rw-r--r--bot/utils/decorators.py110
1 files changed, 61 insertions, 49 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 9cdaad3f..66b30b97 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -13,6 +13,7 @@ from discord.ext.commands import CheckFailure, Command, Context
from bot.constants import ERROR_REPLIES, Month
from bot.utils import human_months, resolve_current_month
+from bot.utils.checks import in_whitelist_check
ONE_DAY = 24 * 60 * 60
@@ -186,82 +187,93 @@ def without_role(*role_ids: int) -> t.Callable:
return commands.check(predicate)
-def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]:
+def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context], bool]:
"""
- Checks that the message is in a whitelisted channel or optionally has a bypass role.
+ Checks if a message is sent in a whitelisted context.
- If `in_channel_override` is present, check if it contains channels
- and use them in place of the global whitelist.
+ All arguments from `in_whitelist_check` are supported, with the exception of "fail_silently".
+ If `whitelist_override` is present, it is added to the global whitelist.
"""
def predicate(ctx: Context) -> bool:
+ # Skip DM invocations
if not ctx.guild:
log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
return True
- if ctx.channel.id in channels:
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command "
- f"and the command was used in a whitelisted channel."
- )
- return True
- if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles):
- log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command and "
- f"had a role to bypass the in_channel check."
- )
- return True
+ kwargs = default_kwargs.copy()
- if hasattr(ctx.command.callback, "in_channel_override"):
- override = ctx.command.callback.in_channel_override
- if override is None:
+ # Update kwargs based on override
+ if hasattr(ctx.command.callback, "override"):
+ # Remove default kwargs if reset is True
+ if ctx.command.callback.override_reset:
+ kwargs = {}
log.debug(
- f"{ctx.author} called the '{ctx.command.name}' command "
- f"and the command was whitelisted to bypass the in_channel check."
+ f"{ctx.author} called the '{ctx.command.name}' command and "
+ f"overrode default checks."
)
- return True
- else:
- if ctx.channel.id in override:
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command "
- f"and the command was used in an overridden whitelisted channel."
- )
- return True
- log.debug(
- f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The overridden in_channel check failed."
- )
- channels_str = ', '.join(f"<#{c_id}>" for c_id in override)
- raise InChannelCheckFailure(
- f"Sorry, but you may only use this command within {channels_str}."
- )
+ # Merge overwrites and defaults
+ for arg in ctx.command.callback.override:
+ default_value = kwargs.get(arg)
+ new_value = ctx.command.callback.override[arg]
+
+ # Skip values that don't need merging, or can't be merged
+ if default_value is None or isinstance(arg, int):
+ kwargs[arg] = new_value
+
+ # Merge containers
+ elif isinstance(default_value, t.Container):
+ if isinstance(new_value, t.Container):
+ kwargs[arg] = (*default_value, *new_value)
+ else:
+ kwargs[arg] = new_value
+
+ log.debug(
+ f"Updated default check arguments for '{ctx.command.name}' "
+ f"invoked by {ctx.author}."
+ )
+
+ log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.")
+ result = in_whitelist_check(ctx, fail_silently=True, **kwargs)
+
+ # Return if check passed
+ if result:
+ log.debug(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command "
+ f"and the command was used in an overridden context."
+ )
+ return result
log.debug(
f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The in_channel check failed."
- )
-
- channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
- raise InChannelCheckFailure(
- f"Sorry, but you may only use this command within {channels_str}."
+ f"The whitelist check failed."
)
- return predicate
+ # Raise error if the check did not pass
+ channels = kwargs.get("channels")
+ if channels:
+ channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
+ message = f"Sorry, but you may only use this command within {channels_str}."
+ else:
+ message = "Sorry, but you may not use this command."
+ raise InChannelCheckFailure(message)
-in_channel = commands.check(in_channel_check)
+ return predicate
-def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable:
+def whitelist_override(bypass_defaults: bool = False, **kwargs: t.Container[int]) -> t.Callable:
"""
- Set command callback attribute for detection in `in_channel_check`.
+ Override global whitelist context, with the kwargs specified.
- Override global whitelist if channels are specified.
+ All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`.
+ Set `bypass_defaults` to True if you want to completely bypass global checks.
This decorator has to go before (below) below the `command` decorator.
"""
def inner(func: t.Callable) -> t.Callable:
- func.in_channel_override = channels
+ func.override = kwargs
+ func.override_reset = bypass_defaults
return func
return inner