diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/error_handler.py | 6 | ||||
| -rw-r--r-- | bot/cogs/information.py | 10 | ||||
| -rw-r--r-- | bot/cogs/snekbox.py | 11 | ||||
| -rw-r--r-- | bot/cogs/utils.py | 8 | ||||
| -rw-r--r-- | bot/cogs/verification.py | 18 | ||||
| -rw-r--r-- | bot/decorators.py | 84 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 4 | 
7 files changed, 94 insertions, 47 deletions
| diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index dae283c6a..3f56a9798 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -9,7 +9,7 @@ from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Channels  from bot.converters import TagNameConverter -from bot.decorators import InChannelCheckFailure +from bot.decorators import InWhitelistedContextCheckFailure  log = logging.getLogger(__name__) @@ -202,7 +202,7 @@ class ErrorHandler(Cog):          * BotMissingRole          * BotMissingAnyRole          * NoPrivateMessage -        * InChannelCheckFailure +        * InWhitelistedContextCheckFailure          """          bot_missing_errors = (              errors.BotMissingPermissions, @@ -215,7 +215,7 @@ class ErrorHandler(Cog):              await ctx.send(                  f"Sorry, it looks like I don't have the permissions or roles I need to do that."              ) -        elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)): +        elif isinstance(e, (InWhitelistedContextCheckFailure, errors.NoPrivateMessage)):              ctx.bot.stats.incr("errors.wrong_channel_or_dm_error")              await ctx.send(e) diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 7921a4932..6b3fc0c96 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -12,7 +12,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.decorators import InChannelCheckFailure, in_channel, with_role +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, with_role  from bot.pagination import LinePaginator  from bot.utils.checks import cooldown_with_role_bypass, with_role_check  from bot.utils.time import time_since @@ -152,7 +152,7 @@ class Information(Cog):          # Non-staff may only do this in #bot-commands          if not with_role_check(ctx, *constants.STAFF_ROLES):              if not ctx.channel.id == constants.Channels.bot_commands: -                raise InChannelCheckFailure(constants.Channels.bot_commands) +                raise InWhitelistedContextCheckFailure(constants.Channels.bot_commands)          embed = await self.create_user_embed(ctx, user) @@ -331,7 +331,11 @@ class Information(Cog):      @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES)      @group(invoke_without_command=True) -    @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES) +    @in_whitelisted_context( +        whitelisted_channels=(constants.Channels.bot_commands,), +        whitelisted_roles=constants.STAFF_ROLES, +        redirect_channel=constants.Channels.bot_commands, +    )      async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None:          """Shows information about the raw API response."""          # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 315383b12..8827cb585 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, command, guild_only  from bot.bot import Bot  from bot.constants import Channels, Roles, URLs -from bot.decorators import in_channel +from bot.decorators import in_whitelisted_context  from bot.utils.messages import wait_for_deletion  log = logging.getLogger(__name__) @@ -38,6 +38,9 @@ RAW_CODE_REGEX = re.compile(  )  MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric)  EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)  SIGKILL = 9 @@ -265,7 +268,11 @@ class Snekbox(Cog):      @command(name="eval", aliases=("e",))      @guild_only() -    @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) +    @in_whitelisted_context( +        whitelisted_channels=EVAL_CHANNELS, +        whitelisted_roles=EVAL_ROLES, +        redirect_channel=Channels.bot_commands, +    )      async def eval_command(self, ctx: Context, *, code: str = None) -> None:          """          Run Python code and get the results. diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 3ed471bbf..234ec514d 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -13,7 +13,7 @@ from discord.ext.commands import BadArgument, Cog, Context, command  from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES -from bot.decorators import in_channel, with_role +from bot.decorators import in_whitelisted_context, with_role  from bot.utils.time import humanize_delta  log = logging.getLogger(__name__) @@ -118,7 +118,11 @@ class Utils(Cog):          await ctx.message.channel.send(embed=pep_embed)      @command() -    @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES) +    @in_whitelisted_context( +        whitelisted_channels=(Channels.bot_commands,), +        whitelisted_roles=STAFF_ROLES, +        redirect_channel=Channels.bot_commands, +    )      async def charinfo(self, ctx: Context, *, characters: str) -> None:          """Shows you information on up to 25 unicode characters."""          match = re.match(r"<(a?):(\w+):(\d+)>", characters) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b0a493e68..040f52fbf 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command  from bot import constants  from bot.bot import Bot  from bot.cogs.moderation import ModLog -from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, without_role  from bot.utils.checks import without_role_check  log = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class Verification(Cog):      @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True)      @without_role(constants.Roles.verified) -    @in_channel(constants.Channels.verification) +    @in_whitelisted_context(whitelisted_channels=(constants.Channels.verification,))      async def accept_command(self, ctx: Context, *_) -> None:  # We don't actually care about the args          """Accept our rules and gain access to the rest of the server."""          log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") @@ -138,7 +138,10 @@ class Verification(Cog):                  await ctx.message.delete()      @command(name='subscribe') -    @in_channel(constants.Channels.bot_commands) +    @in_whitelisted_context( +        whitelisted_channels=(constants.Channels.bot_commands,), +        redirect_channel=constants.Channels.bot_commands, +    )      async def subscribe_command(self, ctx: Context, *_) -> None:  # We don't actually care about the args          """Subscribe to announcement notifications by assigning yourself the role."""          has_role = False @@ -162,7 +165,10 @@ class Verification(Cog):          )      @command(name='unsubscribe') -    @in_channel(constants.Channels.bot_commands) +    @in_whitelisted_context( +        whitelisted_channels=(constants.Channels.bot_commands,), +        redirect_channel=constants.Channels.bot_commands, +    )      async def unsubscribe_command(self, ctx: Context, *_) -> None:  # We don't actually care about the args          """Unsubscribe from announcement notifications by removing the role from yourself."""          has_role = False @@ -187,8 +193,8 @@ class Verification(Cog):      # This cannot be static (must have a __func__ attribute).      async def cog_command_error(self, ctx: Context, error: Exception) -> None: -        """Check for & ignore any InChannelCheckFailure.""" -        if isinstance(error, InChannelCheckFailure): +        """Check for & ignore any InWhitelistedContextCheckFailure.""" +        if isinstance(error, InWhitelistedContextCheckFailure):              error.handled = True      @staticmethod diff --git a/bot/decorators.py b/bot/decorators.py index 2d18eaa6a..149564d18 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -3,7 +3,7 @@ import random  from asyncio import Lock, sleep  from contextlib import suppress  from functools import wraps -from typing import Callable, Container, Union +from typing import Callable, Container, Optional, Union  from weakref import WeakValueDictionary  from discord import Colour, Embed, Member @@ -17,48 +17,74 @@ from bot.utils.checks import with_role_check, without_role_check  log = logging.getLogger(__name__) -class InChannelCheckFailure(CheckFailure): -    """Raised when a check fails for a message being sent in a whitelisted channel.""" +class InWhitelistedContextCheckFailure(CheckFailure): +    """Raised when the `in_whitelist` check fails.""" -    def __init__(self, *channels: int): -        self.channels = channels -        channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) +    def __init__(self, redirect_channel: Optional[int] = None): +        error_message = "Sorry, but you are not allowed to use that command here." -        super().__init__(f"Sorry, but you may only use this command within {channels_str}.") +        if redirect_channel: +            error_message += f" Please use the <#{redirect_channel}> channel instead." +        super().__init__(error_message) + + +def in_whitelisted_context( +    *, +    whitelisted_channels: Container[int] = (), +    whitelisted_categories: Container[int] = (), +    whitelisted_roles: Container[int] = (), +    redirect_channel: Optional[int] = None, -def in_channel( -    *channels: int, -    hidden_channels: Container[int] = None, -    bypass_roles: Container[int] = None  ) -> Callable:      """ -    Checks that the message is in a whitelisted channel or optionally has a bypass role. +    Check if a command was issued in a whitelisted context. + +    The whitelists that can be provided are: -    Hidden channels are channels which will not be displayed in the InChannelCheckFailure error -    message. +    - `channels`: a container with channel ids for whitelisted channels +    - `categories`: a container with category ids for whitelisted categories +    - `roles`: a container with with role ids for whitelisted roles + +    An optional `redirect_channel` can be provided to redirect users that are not +    authorized to use the command in the current context. If no such channel is +    provided, the users are simply told that they are not authorized to use the +    command.      """ -    hidden_channels = hidden_channels or [] -    bypass_roles = bypass_roles or [] +    if redirect_channel and redirect_channel not in whitelisted_channels: +        # It does not make sense for the channel whitelist to not contain the redirection +        # channel (if provided). That's why we add the redirection channel to the `channels` +        # container if it's not already in it. As we allow any container type to be passed, +        # we first create a tuple in order to safely add the redirection channel. +        # +        # Note: It's possible for the redirect channel to be in a whitelisted category, but +        # there's no easy way to check that and as a channel can easily be moved in and out of +        # categories, it's probably not wise to rely on its category in any case. +        whitelisted_channels = tuple(whitelisted_channels) + (redirect_channel,)      def predicate(ctx: Context) -> bool: -        """In-channel checker predicate.""" -        if ctx.channel.id in channels or ctx.channel.id in hidden_channels: -            log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -                      f"The command was used in a whitelisted channel.") +        """Check if a command was issued in a whitelisted context.""" +        if whitelisted_channels and ctx.channel.id in whitelisted_channels: +            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")              return True -        if bypass_roles: -            if any(r.id in bypass_roles for r in ctx.author.roles): -                log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -                          f"The command was not used in a whitelisted channel, " -                          f"but the author had a role to bypass the in_channel check.") -                return True +        # Only check the category id if we have a category whitelist and the channel has a `category_id` +        if ( +            whitelisted_categories +            and hasattr(ctx.channel, "category_id") +            and ctx.channel.category_id in whitelisted_categories +        ): +            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") +            return True -        log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -                  f"The in_channel check failed.") +        # Only check the roles whitelist if we have one and ensure the author's roles attribute returns +        # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). +        if whitelisted_roles and any(r.id in whitelisted_roles for r in getattr(ctx.author, "roles", ())): +            log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") +            return True -        raise InChannelCheckFailure(*channels) +        log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") +        raise InWhitelistedContextCheckFailure(redirect_channel)      return commands.check(predicate) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 3c26374f5..4a36fe030 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,7 +7,7 @@ import discord  from bot import constants  from bot.cogs import information -from bot.decorators import InChannelCheckFailure +from bot.decorators import InWhitelistedContextCheckFailure  from tests import helpers @@ -525,7 +525,7 @@ class UserCommandTests(unittest.TestCase):          ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))          msg = "Sorry, but you may only use this command within <#50>." -        with self.assertRaises(InChannelCheckFailure, msg=msg): +        with self.assertRaises(InWhitelistedContextCheckFailure, msg=msg):              asyncio.run(self.cog.user_info.callback(self.cog, ctx))      @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) | 
