aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/decorators.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 132aaa87..7a3b14ad 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -196,15 +196,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
If `whitelist_override` is present, it is added to the global whitelist.
"""
def predicate(ctx: Context) -> bool:
- # Skip DM invocations
- if not ctx.guild:
- log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
- return True
-
kwargs = default_kwargs.copy()
+ allow_dms = False
# Update kwargs based on override
if hasattr(ctx.command.callback, "override"):
+ # Handle DM invocations
+ allow_dms = ctx.command.callback.override_dm
+
# Remove default kwargs if reset is True
if ctx.command.callback.override_reset:
kwargs = {}
@@ -234,8 +233,12 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
f"invoked by {ctx.author}."
)
- log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.")
- result = in_whitelist_check(ctx, fail_silently=True, **kwargs)
+ if ctx.guild is None:
+ log.debug(f"{ctx.author} tried using the '{ctx.command.name}' command from a DM.")
+ result = allow_dms
+ else:
+ log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.")
+ result = in_whitelist_check(ctx, fail_silently=True, **kwargs)
# Return if check passed
if result:
@@ -260,8 +263,8 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
default_whitelist_channels.discard(Channels.community_bot_commands)
channels.difference_update(default_whitelist_channels)
- # Add all whitelisted category channels
- if categories:
+ # Add all whitelisted category channels, but skip if we're in DMs
+ if categories and ctx.guild is not None:
for category_id in categories:
category = ctx.guild.get_channel(category_id)
if category is None:
@@ -280,18 +283,22 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
return predicate
-def whitelist_override(bypass_defaults: bool = False, **kwargs: Container[int]) -> Callable:
+def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **kwargs: Container[int]) -> Callable:
"""
Override global whitelist context, with the kwargs specified.
All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`.
Set `bypass_defaults` to True if you want to completely bypass global checks.
+ Set `allow_dm` to True if you want to allow the command to be invoked from within direct messages.
+ Note that you have to be careful with any references to the guild.
+
This decorator has to go before (below) below the `command` decorator.
"""
def inner(func: Callable) -> Callable:
func.override = kwargs
func.override_reset = bypass_defaults
+ func.override_dm = allow_dm
return func
return inner