diff options
| author | 2020-05-27 19:30:16 +0200 | |
|---|---|---|
| committer | 2020-05-27 19:30:16 +0200 | |
| commit | 9027724d935f2f9ba754e5ad6391a63aa324824f (patch) | |
| tree | 08b1d230da04f6613c39dff19a6e105f924d3e88 | |
| parent | Add /r/FlutterDev to the guild invite whitelist (diff) | |
| parent | Add some tests for `in_whitelist_check`. (diff) | |
Merge pull request #961 from python-discord/moderation_commands_in_modmail_category
Permit moderation commands in ModMail category
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/error_handler.py | 6 | ||||
| -rw-r--r-- | bot/cogs/information.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 22 | ||||
| -rw-r--r-- | bot/cogs/verification.py | 4 | ||||
| -rw-r--r-- | bot/constants.py | 5 | ||||
| -rw-r--r-- | bot/decorators.py | 55 | ||||
| -rw-r--r-- | bot/utils/checks.py | 94 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 3 | ||||
| -rw-r--r-- | tests/bot/test_decorators.py | 4 | ||||
| -rw-r--r-- | tests/bot/utils/test_checks.py | 52 | 
10 files changed, 161 insertions, 88 deletions
| diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 23d1eed82..5de961116 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -9,7 +9,7 @@ from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Channels  from bot.converters import TagNameConverter -from bot.decorators import InWhitelistCheckFailure +from bot.utils.checks import InWhitelistCheckFailure  log = logging.getLogger(__name__) @@ -166,7 +166,7 @@ class ErrorHandler(Cog):              await prepared_help_command              self.bot.stats.incr("errors.missing_required_argument")          elif isinstance(e, errors.TooManyArguments): -            await ctx.send(f"Too many arguments provided.") +            await ctx.send("Too many arguments provided.")              await prepared_help_command              self.bot.stats.incr("errors.too_many_arguments")          elif isinstance(e, errors.BadArgument): @@ -206,7 +206,7 @@ class ErrorHandler(Cog):          if isinstance(e, bot_missing_errors):              ctx.bot.stats.incr("errors.bot_permission_error")              await ctx.send( -                f"Sorry, it looks like I don't have the permissions or roles I need to do that." +                "Sorry, it looks like I don't have the permissions or roles I need to do that."              )          elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)):              ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") diff --git a/bot/cogs/information.py b/bot/cogs/information.py index ef2f308ca..f0eb3a1ea 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -12,9 +12,9 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.decorators import InWhitelistCheckFailure, in_whitelist, with_role +from bot.decorators import in_whitelist, with_role  from bot.pagination import LinePaginator -from bot.utils.checks import cooldown_with_role_bypass, with_role_check +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check  from bot.utils.time import time_since  log = logging.getLogger(__name__) diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index edfdfd9e2..c39c7f3bc 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -12,7 +12,7 @@ from bot.bot import Bot  from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user  from bot.pagination import LinePaginator  from bot.utils import time -from bot.utils.checks import in_channel_check, with_role_check +from bot.utils.checks import in_whitelist_check, with_role_check  from . import utils  from .infractions import Infractions  from .modlog import ModLog @@ -49,8 +49,8 @@ class ModManagement(commands.Cog):      async def infraction_edit(          self,          ctx: Context, -        infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], -        duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], +        infraction_id: t.Union[int, allowed_strings("l", "last", "recent")],  # noqa: F821 +        duration: t.Union[Expiry, allowed_strings("p", "permanent"), None],   # noqa: F821          *,          reason: str = None      ) -> None: @@ -83,14 +83,14 @@ class ModManagement(commands.Cog):                  "actor__id": ctx.author.id,                  "ordering": "-inserted_at"              } -            infractions = await self.bot.api_client.get(f"bot/infractions", params=params) +            infractions = await self.bot.api_client.get("bot/infractions", params=params)              if infractions:                  old_infraction = infractions[0]                  infraction_id = old_infraction["id"]              else:                  await ctx.send( -                    f":x: Couldn't find most recent infraction; you have never given an infraction." +                    ":x: Couldn't find most recent infraction; you have never given an infraction."                  )                  return          else: @@ -224,7 +224,7 @@ class ModManagement(commands.Cog):      ) -> None:          """Send a paginated embed of infractions for the specified user."""          if not infractions: -            await ctx.send(f":warning: No infractions could be found for that query.") +            await ctx.send(":warning: No infractions could be found for that query.")              return          lines = tuple( @@ -283,10 +283,16 @@ class ModManagement(commands.Cog):      # This cannot be static (must have a __func__ attribute).      def cog_check(self, ctx: Context) -> bool: -        """Only allow moderators from moderator channels to invoke the commands in this cog.""" +        """Only allow moderators inside moderator channels to invoke the commands in this cog."""          checks = [              with_role_check(ctx, *constants.MODERATION_ROLES), -            in_channel_check(ctx, *constants.MODERATION_CHANNELS) +            in_whitelist_check( +                ctx, +                channels=constants.MODERATION_CHANNELS, +                categories=[constants.Categories.modmail], +                redirect=None, +                fail_silently=True, +            )          ]          return all(checks) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 77e8b5706..99be3cdaa 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,8 +9,8 @@ from discord.ext.commands import Cog, Context, command  from bot import constants  from bot.bot import Bot  from bot.cogs.moderation import ModLog -from bot.decorators import InWhitelistCheckFailure, in_whitelist, without_role -from bot.utils.checks import without_role_check +from bot.decorators import in_whitelist, without_role +from bot.utils.checks import InWhitelistCheckFailure, without_role_check  log = logging.getLogger(__name__) diff --git a/bot/constants.py b/bot/constants.py index 39de2ee41..2ce5355be 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -612,13 +612,10 @@ PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir))  MODERATION_ROLES = Guild.moderation_roles  STAFF_ROLES = Guild.staff_roles -# Roles combinations +# Channel combinations  STAFF_CHANNELS = Guild.staff_channels - -# Default Channel combinations  MODERATION_CHANNELS = Guild.moderation_channels -  # Bot replies  NEGATIVE_REPLIES = [      "Noooooo!!", diff --git a/bot/decorators.py b/bot/decorators.py index 306f0830c..500197c89 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -9,37 +9,21 @@ from weakref import WeakValueDictionary  from discord import Colour, Embed, Member  from discord.errors import NotFound  from discord.ext import commands -from discord.ext.commands import CheckFailure, Cog, Context +from discord.ext.commands import Cog, Context  from bot.constants import Channels, ERROR_REPLIES, RedirectOutput -from bot.utils.checks import with_role_check, without_role_check +from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check  log = logging.getLogger(__name__) -class InWhitelistCheckFailure(CheckFailure): -    """Raised when the `in_whitelist` check fails.""" - -    def __init__(self, redirect_channel: Optional[int]) -> None: -        self.redirect_channel = redirect_channel - -        if redirect_channel: -            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" -        else: -            redirect_message = "" - -        error_message = f"You are not allowed to use that command{redirect_message}." - -        super().__init__(error_message) - -  def in_whitelist(      *,      channels: Container[int] = (),      categories: Container[int] = (),      roles: Container[int] = (),      redirect: Optional[int] = Channels.bot_commands, - +    fail_silently: bool = False,  ) -> Callable:      """      Check if a command was issued in a whitelisted context. @@ -54,36 +38,9 @@ def in_whitelist(      redirected to the `redirect` channel that was passed (default: #bot-commands) or simply      told that they're not allowed to use this particular command (if `None` was passed).      """ -    if redirect and redirect not in channels: -        # It does not make sense for the channel whitelist to not contain the redirection -        # channel (if applicable). That's why we add the redirection channel to the `channels` -        # container if it's not already in it. As we allow any container type to be passed, -        # we first create a tuple in order to safely add the redirection channel. -        # -        # Note: It's possible for the redirect channel to be in a whitelisted category, but -        # there's no easy way to check that and as a channel can easily be moved in and out of -        # categories, it's probably not wise to rely on its category in any case. -        channels = tuple(channels) + (redirect,) -      def predicate(ctx: Context) -> bool: -        """Check if a command was issued in a whitelisted context.""" -        if channels and ctx.channel.id in channels: -            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") -            return True - -        # Only check the category id if we have a category whitelist and the channel has a `category_id` -        if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: -            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") -            return True - -        # Only check the roles whitelist if we have one and ensure the author's roles attribute returns -        # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). -        if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): -            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") -            return True - -        log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") -        raise InWhitelistCheckFailure(redirect) +        """Check if command was issued in a whitelisted context.""" +        return in_whitelist_check(ctx, channels, categories, roles, redirect, fail_silently)      return commands.check(predicate) @@ -121,7 +78,7 @@ def locked() -> Callable:                  embed = Embed()                  embed.colour = Colour.red() -                log.debug(f"User tried to invoke a locked command.") +                log.debug("User tried to invoke a locked command.")                  embed.description = (                      "You're already using this command. Please wait until it is done before you use it again."                  ) diff --git a/bot/utils/checks.py b/bot/utils/checks.py index db56c347c..f0ef36302 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,12 +1,94 @@  import datetime  import logging -from typing import Callable, Iterable +from typing import Callable, Container, Iterable, Optional -from discord.ext.commands import BucketType, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping +from discord.ext.commands import ( +    BucketType, +    CheckFailure, +    Cog, +    Command, +    CommandOnCooldown, +    Context, +    Cooldown, +    CooldownMapping, +) + +from bot import constants  log = logging.getLogger(__name__) +class InWhitelistCheckFailure(CheckFailure): +    """Raised when the `in_whitelist` check fails.""" + +    def __init__(self, redirect_channel: Optional[int]) -> None: +        self.redirect_channel = redirect_channel + +        if redirect_channel: +            redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +        else: +            redirect_message = "" + +        error_message = f"You are not allowed to use that command{redirect_message}." + +        super().__init__(error_message) + + +def in_whitelist_check( +    ctx: Context, +    channels: Container[int] = (), +    categories: Container[int] = (), +    roles: Container[int] = (), +    redirect: Optional[int] = constants.Channels.bot_commands, +    fail_silently: bool = False, +) -> bool: +    """ +    Check if a command was issued in a whitelisted context. + +    The whitelists that can be provided are: + +    - `channels`: a container with channel ids for whitelisted channels +    - `categories`: a container with category ids for whitelisted categories +    - `roles`: a container with with role ids for whitelisted roles + +    If the command was invoked in a context that was not whitelisted, the member is either +    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply +    told that they're not allowed to use this particular command (if `None` was passed). +    """ +    if redirect and redirect not in channels: +        # It does not make sense for the channel whitelist to not contain the redirection +        # channel (if applicable). That's why we add the redirection channel to the `channels` +        # container if it's not already in it. As we allow any container type to be passed, +        # we first create a tuple in order to safely add the redirection channel. +        # +        # Note: It's possible for the redirect channel to be in a whitelisted category, but +        # there's no easy way to check that and as a channel can easily be moved in and out of +        # categories, it's probably not wise to rely on its category in any case. +        channels = tuple(channels) + (redirect,) + +    if channels and ctx.channel.id in channels: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") +        return True + +    # Only check the category id if we have a category whitelist and the channel has a `category_id` +    if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories: +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") +        return True + +    # Only check the roles whitelist if we have one and ensure the author's roles attribute returns +    # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). +    if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())): +        log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") +        return True + +    log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + +    # Some commands are secret, and should produce no feedback at all. +    if not fail_silently: +        raise InWhitelistCheckFailure(redirect) +    return False + +  def with_role_check(ctx: Context, *role_ids: int) -> bool:      """Returns True if the user has any one of the roles in role_ids."""      if not ctx.guild:  # Return False in a DM @@ -38,14 +120,6 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:      return check -def in_channel_check(ctx: Context, *channel_ids: int) -> bool: -    """Checks if the command was executed inside the list of specified channels.""" -    check = ctx.channel.id in channel_ids -    log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -              f"The result of the in_channel check was {check}.") -    return check - -  def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,                                bypass_roles: Iterable[int]) -> Callable:      """ diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index b5f928dd6..aca6b594f 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,10 +7,9 @@ import discord  from bot import constants  from bot.cogs import information -from bot.decorators import InWhitelistCheckFailure +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  COG_PATH = "bot.cogs.information.Information" diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py index a17dd3e16..3d450caa0 100644 --- a/tests/bot/test_decorators.py +++ b/tests/bot/test_decorators.py @@ -3,10 +3,10 @@ import unittest  import unittest.mock  from bot import constants -from bot.decorators import InWhitelistCheckFailure, in_whitelist +from bot.decorators import in_whitelist +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 9610771e5..de72e5748 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,6 +1,8 @@  import unittest +from unittest.mock import MagicMock  from bot.utils import checks +from bot.utils.checks import InWhitelistCheckFailure  from tests.helpers import MockContext, MockRole @@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):          self.ctx.author.roles.append(MockRole(id=role_id))          self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) -    def test_in_channel_check_for_correct_channel(self): -        self.ctx.channel.id = 42 -        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_correct_channel(self): +        """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" +        channel_id = 3 +        self.ctx.channel.id = channel_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id])) -    def test_in_channel_check_for_incorrect_channel(self): -        self.ctx.channel.id = 42 + 10 -        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_incorrect_channel(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match.""" +        self.ctx.channel.id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, [4]) + +    def test_in_whitelist_check_correct_category(self): +        """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list.""" +        category_id = 3 +        self.ctx.channel.category_id = category_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id])) + +    def test_in_whitelist_check_incorrect_category(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match.""" +        self.ctx.channel.category_id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, categories=[4]) + +    def test_in_whitelist_check_correct_role(self): +        """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6])) + +    def test_in_whitelist_check_incorrect_role(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, roles=[4]) + +    def test_in_whitelist_check_fail_silently(self): +        """`in_whitelist_check` test no exception raised if `fail_silently` is `True`""" +        self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True)) + +    def test_in_whitelist_check_complex(self): +        """`in_whitelist_check` test with multiple parameters""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.ctx.channel.category_id = 3 +        self.ctx.channel.id = 5 +        self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2])) | 
