diff options
| author | 2020-05-27 07:32:16 +0200 | |
|---|---|---|
| committer | 2020-05-27 07:32:16 +0200 | |
| commit | d9190d997538f49c0a1b53d63a15bada3c89297f (patch) | |
| tree | e1a2db4240f97ed7c183b3fc7e1f718e313bafc7 | |
| parent | Merge pull request #866 from python-discord/restricted_tags (diff) | |
Refactor the in_whitelist deco to a check.
We're moving the actual predicate into the `utils.checks` folder, just
like we're doing with most of the other decorators. This is to allow us
the flexibility to use it as a pure check, not only as a decorator.
This commit doesn't actually change any functionality, just moves it
around.
Diffstat (limited to '')
| -rw-r--r-- | bot/decorators.py | 54 | ||||
| -rw-r--r-- | bot/utils/checks.py | 81 | ||||
| -rw-r--r-- | tests/bot/test_decorators.py | 4 | 
3 files changed, 86 insertions, 53 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index 306f0830c..1e77afe60 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -9,37 +9,20 @@ from weakref import WeakValueDictionary  from discord import Colour, Embed, Member  from discord.errors import NotFound  from discord.ext import commands -from discord.ext.commands import CheckFailure, Cog, Context +from discord.ext.commands import Cog, Context  from bot.constants import Channels, ERROR_REPLIES, RedirectOutput -from bot.utils.checks import with_role_check, without_role_check +from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check  log = logging.getLogger(__name__) -class InWhitelistCheckFailure(CheckFailure): -    """Raised when the `in_whitelist` check fails.""" - -    def __init__(self, redirect_channel: Optional[int]) -> None: -        self.redirect_channel = redirect_channel - -        if redirect_channel: -            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" -        else: -            redirect_message = "" - -        error_message = f"You are not allowed to use that command{redirect_message}." - -        super().__init__(error_message) - -  def in_whitelist(      *,      channels: Container[int] = (),      categories: Container[int] = (),      roles: Container[int] = (),      redirect: Optional[int] = Channels.bot_commands, -  ) -> Callable:      """      Check if a command was issued in a whitelisted context. @@ -54,36 +37,9 @@ def in_whitelist(      redirected to the `redirect` channel that was passed (default: #bot-commands) or simply      told that they're not allowed to use this particular command (if `None` was passed).      """ -    if redirect and redirect not in channels: -        # It does not make sense for the channel whitelist to not contain the redirection -        # channel (if applicable). That's why we add the redirection channel to the `channels` -        # container if it's not already in it. As we allow any container type to be passed, -        # we first create a tuple in order to safely add the redirection channel. -        # -        # Note: It's possible for the redirect channel to be in a whitelisted category, but -        # there's no easy way to check that and as a channel can easily be moved in and out of -        # categories, it's probably not wise to rely on its category in any case. -        channels = tuple(channels) + (redirect,) -      def predicate(ctx: Context) -> bool: -        """Check if a command was issued in a whitelisted context.""" -        if channels and ctx.channel.id in channels: -            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") -            return True - -        # Only check the category id if we have a category whitelist and the channel has a `category_id` -        if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: -            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") -            return True - -        # Only check the roles whitelist if we have one and ensure the author's roles attribute returns -        # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). -        if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): -            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") -            return True - -        log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") -        raise InWhitelistCheckFailure(redirect) +        """Check if command was issued in a whitelisted context.""" +        return in_whitelist_check(ctx, channels, categories, roles, redirect)      return commands.check(predicate) @@ -121,7 +77,7 @@ def locked() -> Callable:                  embed = Embed()                  embed.colour = Colour.red() -                log.debug(f"User tried to invoke a locked command.") +                log.debug("User tried to invoke a locked command.")                  embed.description = (                      "You're already using this command. Please wait until it is done before you use it again."                  ) diff --git a/bot/utils/checks.py b/bot/utils/checks.py index db56c347c..63568b29e 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,12 +1,89 @@  import datetime  import logging -from typing import Callable, Iterable +from typing import Callable, Container, Iterable, Optional -from discord.ext.commands import BucketType, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping +from discord.ext.commands import ( +    BucketType, +    CheckFailure, +    Cog, +    Command, +    CommandOnCooldown, +    Context, +    Cooldown, +    CooldownMapping, +) + +from bot import constants  log = logging.getLogger(__name__) +class InWhitelistCheckFailure(CheckFailure): +    """Raised when the `in_whitelist` check fails.""" + +    def __init__(self, redirect_channel: Optional[int]) -> None: +        self.redirect_channel = redirect_channel + +        if redirect_channel: +            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +        else: +            redirect_message = "" + +        error_message = f"You are not allowed to use that command{redirect_message}." + +        super().__init__(error_message) + + +def in_whitelist_check( +    ctx: Context, +    channels: Container[int] = (), +    categories: Container[int] = (), +    roles: Container[int] = (), +    redirect: Optional[int] = constants.Channels.bot_commands, +) -> bool: +    """ +    Check if a command was issued in a whitelisted context. + +    The whitelists that can be provided are: + +    - `channels`: a container with channel ids for whitelisted channels +    - `categories`: a container with category ids for whitelisted categories +    - `roles`: a container with with role ids for whitelisted roles + +    If the command was invoked in a context that was not whitelisted, the member is either +    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply +    told that they're not allowed to use this particular command (if `None` was passed). +    """ +    if redirect and redirect not in channels: +        # It does not make sense for the channel whitelist to not contain the redirection +        # channel (if applicable). That's why we add the redirection channel to the `channels` +        # container if it's not already in it. As we allow any container type to be passed, +        # we first create a tuple in order to safely add the redirection channel. +        # +        # Note: It's possible for the redirect channel to be in a whitelisted category, but +        # there's no easy way to check that and as a channel can easily be moved in and out of +        # categories, it's probably not wise to rely on its category in any case. +        channels = tuple(channels) + (redirect,) + +    if channels and ctx.channel.id in channels: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") +        return True + +    # Only check the category id if we have a category whitelist and the channel has a `category_id` +    if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") +        return True + +    # Only check the roles whitelist if we have one and ensure the author's roles attribute returns +    # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). +    if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") +        return True + +    log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") +    raise InWhitelistCheckFailure(redirect) + +  def with_role_check(ctx: Context, *role_ids: int) -> bool:      """Returns True if the user has any one of the roles in role_ids."""      if not ctx.guild:  # Return False in a DM diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index a17dd3e16..3d450caa0 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -3,10 +3,10 @@ import unittest  import unittest.mock  from bot import constants -from bot.decorators import InWhitelistCheckFailure, in_whitelist +from bot.decorators import in_whitelist +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) | 
