From 0073e85b0adce4de6246a2e1013e3d84808de481 Mon Sep 17 00:00:00 2001 From: Hassan Abouelela <47495861+HassanAbouelela@users.noreply.github.com> Date: Sun, 7 Feb 2021 00:39:09 +0300 Subject: Overhauls In Channel Check Upgrades in channel check to support categories, and in the case of overrides, roles too. Signed-off-by: Hassan Abouelela <47495861+HassanAbouelela@users.noreply.github.com> --- bot/utils/decorators.py | 110 +++++++++++++++++++++++++++--------------------- 1 file changed, 61 insertions(+), 49 deletions(-) (limited to 'bot/utils/decorators.py') diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 9cdaad3f..66b30b97 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -13,6 +13,7 @@ from discord.ext.commands import CheckFailure, Command, Context from bot.constants import ERROR_REPLIES, Month from bot.utils import human_months, resolve_current_month +from bot.utils.checks import in_whitelist_check ONE_DAY = 24 * 60 * 60 @@ -186,82 +187,93 @@ def without_role(*role_ids: int) -> t.Callable: return commands.check(predicate) -def in_channel_check(*channels: int, bypass_roles: t.Container[int] = None) -> t.Callable[[Context], bool]: +def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context], bool]: """ - Checks that the message is in a whitelisted channel or optionally has a bypass role. + Checks if a message is sent in a whitelisted context. - If `in_channel_override` is present, check if it contains channels - and use them in place of the global whitelist. + All arguments from `in_whitelist_check` are supported, with the exception of "fail_silently". + If `whitelist_override` is present, it is added to the global whitelist. """ def predicate(ctx: Context) -> bool: + # Skip DM invocations if not ctx.guild: log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") return True - if ctx.channel.id in channels: - log.debug( - f"{ctx.author} tried to call the '{ctx.command.name}' command " - f"and the command was used in a whitelisted channel." - ) - return True - if bypass_roles and any(r.id in bypass_roles for r in ctx.author.roles): - log.debug( - f"{ctx.author} called the '{ctx.command.name}' command and " - f"had a role to bypass the in_channel check." - ) - return True + kwargs = default_kwargs.copy() - if hasattr(ctx.command.callback, "in_channel_override"): - override = ctx.command.callback.in_channel_override - if override is None: + # Update kwargs based on override + if hasattr(ctx.command.callback, "override"): + # Remove default kwargs if reset is True + if ctx.command.callback.override_reset: + kwargs = {} log.debug( - f"{ctx.author} called the '{ctx.command.name}' command " - f"and the command was whitelisted to bypass the in_channel check." + f"{ctx.author} called the '{ctx.command.name}' command and " + f"overrode default checks." ) - return True - else: - if ctx.channel.id in override: - log.debug( - f"{ctx.author} tried to call the '{ctx.command.name}' command " - f"and the command was used in an overridden whitelisted channel." - ) - return True - log.debug( - f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The overridden in_channel check failed." - ) - channels_str = ', '.join(f"<#{c_id}>" for c_id in override) - raise InChannelCheckFailure( - f"Sorry, but you may only use this command within {channels_str}." - ) + # Merge overwrites and defaults + for arg in ctx.command.callback.override: + default_value = kwargs.get(arg) + new_value = ctx.command.callback.override[arg] + + # Skip values that don't need merging, or can't be merged + if default_value is None or isinstance(arg, int): + kwargs[arg] = new_value + + # Merge containers + elif isinstance(default_value, t.Container): + if isinstance(new_value, t.Container): + kwargs[arg] = (*default_value, *new_value) + else: + kwargs[arg] = new_value + + log.debug( + f"Updated default check arguments for '{ctx.command.name}' " + f"invoked by {ctx.author}." + ) + + log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") + result = in_whitelist_check(ctx, fail_silently=True, **kwargs) + + # Return if check passed + if result: + log.debug( + f"{ctx.author} tried to call the '{ctx.command.name}' command " + f"and the command was used in an overridden context." + ) + return result log.debug( f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The in_channel check failed." - ) - - channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) - raise InChannelCheckFailure( - f"Sorry, but you may only use this command within {channels_str}." + f"The whitelist check failed." ) - return predicate + # Raise error if the check did not pass + channels = kwargs.get("channels") + if channels: + channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) + message = f"Sorry, but you may only use this command within {channels_str}." + else: + message = "Sorry, but you may not use this command." + raise InChannelCheckFailure(message) -in_channel = commands.check(in_channel_check) + return predicate -def override_in_channel(channels: t.Tuple[int] = None) -> t.Callable: +def whitelist_override(bypass_defaults: bool = False, **kwargs: t.Container[int]) -> t.Callable: """ - Set command callback attribute for detection in `in_channel_check`. + Override global whitelist context, with the kwargs specified. - Override global whitelist if channels are specified. + All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`. + Set `bypass_defaults` to True if you want to completely bypass global checks. This decorator has to go before (below) below the `command` decorator. """ def inner(func: t.Callable) -> t.Callable: - func.in_channel_override = channels + func.override = kwargs + func.override_reset = bypass_defaults return func return inner -- cgit v1.2.3 From 73f11e22deda8bc25a03c75f3b7f380dd7f67f4e Mon Sep 17 00:00:00 2001 From: Hassan Abouelela <47495861+HassanAbouelela@users.noreply.github.com> Date: Sun, 7 Feb 2021 10:09:06 +0300 Subject: Adds Category Channels To Error Message Adds the channels within categories to the failure message of the command whitelist check. Signed-off-by: Hassan Abouelela <47495861+HassanAbouelela@users.noreply.github.com> --- bot/utils/decorators.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'bot/utils/decorators.py') diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 66b30b97..c12a15ff 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -250,7 +250,18 @@ def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context], ) # Raise error if the check did not pass - channels = kwargs.get("channels") + channels = set(kwargs.get("channels") or {}) + categories = kwargs.get("categories") + + # Add all whitelisted category channels + if categories: + for category_id in categories: + category = ctx.guild.get_channel(category_id) + if category is None: + continue + + [channels.add(channel.id) for channel in category.text_channels] + if channels: channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) message = f"Sorry, but you may only use this command within {channels_str}." -- cgit v1.2.3