diff options
-rw-r--r-- | bot/cogs/error_handler.py | 6 | ||||
-rw-r--r-- | bot/cogs/information.py | 10 | ||||
-rw-r--r-- | bot/cogs/snekbox.py | 11 | ||||
-rw-r--r-- | bot/cogs/utils.py | 8 | ||||
-rw-r--r-- | bot/cogs/verification.py | 18 | ||||
-rw-r--r-- | bot/decorators.py | 84 | ||||
-rw-r--r-- | tests/bot/cogs/test_information.py | 4 |
7 files changed, 94 insertions, 47 deletions
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index dae283c6a..3f56a9798 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -9,7 +9,7 @@ from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels from bot.converters import TagNameConverter -from bot.decorators import InChannelCheckFailure +from bot.decorators import InWhitelistedContextCheckFailure log = logging.getLogger(__name__) @@ -202,7 +202,7 @@ class ErrorHandler(Cog): * BotMissingRole * BotMissingAnyRole * NoPrivateMessage - * InChannelCheckFailure + * InWhitelistedContextCheckFailure """ bot_missing_errors = ( errors.BotMissingPermissions, @@ -215,7 +215,7 @@ class ErrorHandler(Cog): await ctx.send( f"Sorry, it looks like I don't have the permissions or roles I need to do that." ) - elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)): + elif isinstance(e, (InWhitelistedContextCheckFailure, errors.NoPrivateMessage)): ctx.bot.stats.incr("errors.wrong_channel_or_dm_error") await ctx.send(e) diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 7921a4932..6b3fc0c96 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -12,7 +12,7 @@ from discord.utils import escape_markdown from bot import constants from bot.bot import Bot -from bot.decorators import InChannelCheckFailure, in_channel, with_role +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, with_role from bot.pagination import LinePaginator from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since @@ -152,7 +152,7 @@ class Information(Cog): # Non-staff may only do this in #bot-commands if not with_role_check(ctx, *constants.STAFF_ROLES): if not ctx.channel.id == constants.Channels.bot_commands: - raise InChannelCheckFailure(constants.Channels.bot_commands) + raise InWhitelistedContextCheckFailure(constants.Channels.bot_commands) embed = await self.create_user_embed(ctx, user) @@ -331,7 +331,11 @@ class Information(Cog): @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) @group(invoke_without_command=True) - @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES) + @in_whitelisted_context( + whitelisted_channels=(constants.Channels.bot_commands,), + whitelisted_roles=constants.STAFF_ROLES, + redirect_channel=constants.Channels.bot_commands, + ) async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: """Shows information about the raw API response.""" # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 315383b12..8827cb585 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot from bot.constants import Channels, Roles, URLs -from bot.decorators import in_channel +from bot.decorators import in_whitelisted_context from bot.utils.messages import wait_for_deletion log = logging.getLogger(__name__) @@ -38,6 +38,9 @@ RAW_CODE_REGEX = re.compile( ) MAX_PASTE_LEN = 1000 + +# `!eval` command whitelists +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) SIGKILL = 9 @@ -265,7 +268,11 @@ class Snekbox(Cog): @command(name="eval", aliases=("e",)) @guild_only() - @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) + @in_whitelisted_context( + whitelisted_channels=EVAL_CHANNELS, + whitelisted_roles=EVAL_ROLES, + redirect_channel=Channels.bot_commands, + ) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 3ed471bbf..234ec514d 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -13,7 +13,7 @@ from discord.ext.commands import BadArgument, Cog, Context, command from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES -from bot.decorators import in_channel, with_role +from bot.decorators import in_whitelisted_context, with_role from bot.utils.time import humanize_delta log = logging.getLogger(__name__) @@ -118,7 +118,11 @@ class Utils(Cog): await ctx.message.channel.send(embed=pep_embed) @command() - @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES) + @in_whitelisted_context( + whitelisted_channels=(Channels.bot_commands,), + whitelisted_roles=STAFF_ROLES, + redirect_channel=Channels.bot_commands, + ) async def charinfo(self, ctx: Context, *, characters: str) -> None: """Shows you information on up to 25 unicode characters.""" match = re.match(r"<(a?):(\w+):(\d+)>", characters) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b0a493e68..040f52fbf 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command from bot import constants from bot.bot import Bot from bot.cogs.moderation import ModLog -from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.decorators import InWhitelistedContextCheckFailure, in_whitelisted_context, without_role from bot.utils.checks import without_role_check log = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class Verification(Cog): @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) @without_role(constants.Roles.verified) - @in_channel(constants.Channels.verification) + @in_whitelisted_context(whitelisted_channels=(constants.Channels.verification,)) async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Accept our rules and gain access to the rest of the server.""" log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.") @@ -138,7 +138,10 @@ class Verification(Cog): await ctx.message.delete() @command(name='subscribe') - @in_channel(constants.Channels.bot_commands) + @in_whitelisted_context( + whitelisted_channels=(constants.Channels.bot_commands,), + redirect_channel=constants.Channels.bot_commands, + ) async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Subscribe to announcement notifications by assigning yourself the role.""" has_role = False @@ -162,7 +165,10 @@ class Verification(Cog): ) @command(name='unsubscribe') - @in_channel(constants.Channels.bot_commands) + @in_whitelisted_context( + whitelisted_channels=(constants.Channels.bot_commands,), + redirect_channel=constants.Channels.bot_commands, + ) async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Unsubscribe from announcement notifications by removing the role from yourself.""" has_role = False @@ -187,8 +193,8 @@ class Verification(Cog): # This cannot be static (must have a __func__ attribute). async def cog_command_error(self, ctx: Context, error: Exception) -> None: - """Check for & ignore any InChannelCheckFailure.""" - if isinstance(error, InChannelCheckFailure): + """Check for & ignore any InWhitelistedContextCheckFailure.""" + if isinstance(error, InWhitelistedContextCheckFailure): error.handled = True @staticmethod diff --git a/bot/decorators.py b/bot/decorators.py index 2d18eaa6a..149564d18 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -3,7 +3,7 @@ import random from asyncio import Lock, sleep from contextlib import suppress from functools import wraps -from typing import Callable, Container, Union +from typing import Callable, Container, Optional, Union from weakref import WeakValueDictionary from discord import Colour, Embed, Member @@ -17,48 +17,74 @@ from bot.utils.checks import with_role_check, without_role_check log = logging.getLogger(__name__) -class InChannelCheckFailure(CheckFailure): - """Raised when a check fails for a message being sent in a whitelisted channel.""" +class InWhitelistedContextCheckFailure(CheckFailure): + """Raised when the `in_whitelist` check fails.""" - def __init__(self, *channels: int): - self.channels = channels - channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) + def __init__(self, redirect_channel: Optional[int] = None): + error_message = "Sorry, but you are not allowed to use that command here." - super().__init__(f"Sorry, but you may only use this command within {channels_str}.") + if redirect_channel: + error_message += f" Please use the <#{redirect_channel}> channel instead." + super().__init__(error_message) + + +def in_whitelisted_context( + *, + whitelisted_channels: Container[int] = (), + whitelisted_categories: Container[int] = (), + whitelisted_roles: Container[int] = (), + redirect_channel: Optional[int] = None, -def in_channel( - *channels: int, - hidden_channels: Container[int] = None, - bypass_roles: Container[int] = None ) -> Callable: """ - Checks that the message is in a whitelisted channel or optionally has a bypass role. + Check if a command was issued in a whitelisted context. + + The whitelists that can be provided are: - Hidden channels are channels which will not be displayed in the InChannelCheckFailure error - message. + - `channels`: a container with channel ids for whitelisted channels + - `categories`: a container with category ids for whitelisted categories + - `roles`: a container with with role ids for whitelisted roles + + An optional `redirect_channel` can be provided to redirect users that are not + authorized to use the command in the current context. If no such channel is + provided, the users are simply told that they are not authorized to use the + command. """ - hidden_channels = hidden_channels or [] - bypass_roles = bypass_roles or [] + if redirect_channel and redirect_channel not in whitelisted_channels: + # It does not make sense for the channel whitelist to not contain the redirection + # channel (if provided). That's why we add the redirection channel to the `channels` + # container if it's not already in it. As we allow any container type to be passed, + # we first create a tuple in order to safely add the redirection channel. + # + # Note: It's possible for the redirect channel to be in a whitelisted category, but + # there's no easy way to check that and as a channel can easily be moved in and out of + # categories, it's probably not wise to rely on its category in any case. + whitelisted_channels = tuple(whitelisted_channels) + (redirect_channel,) def predicate(ctx: Context) -> bool: - """In-channel checker predicate.""" - if ctx.channel.id in channels or ctx.channel.id in hidden_channels: - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The command was used in a whitelisted channel.") + """Check if a command was issued in a whitelisted context.""" + if whitelisted_channels and ctx.channel.id in whitelisted_channels: + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.") return True - if bypass_roles: - if any(r.id in bypass_roles for r in ctx.author.roles): - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The command was not used in a whitelisted channel, " - f"but the author had a role to bypass the in_channel check.") - return True + # Only check the category id if we have a category whitelist and the channel has a `category_id` + if ( + whitelisted_categories + and hasattr(ctx.channel, "category_id") + and ctx.channel.category_id in whitelisted_categories + ): + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.") + return True - log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The in_channel check failed.") + # Only check the roles whitelist if we have one and ensure the author's roles attribute returns + # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there). + if whitelisted_roles and any(r.id in whitelisted_roles for r in getattr(ctx.author, "roles", ())): + log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.") + return True - raise InChannelCheckFailure(*channels) + log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.") + raise InWhitelistedContextCheckFailure(redirect_channel) return commands.check(predicate) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 3c26374f5..4a36fe030 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,7 +7,7 @@ import discord from bot import constants from bot.cogs import information -from bot.decorators import InChannelCheckFailure +from bot.decorators import InWhitelistedContextCheckFailure from tests import helpers @@ -525,7 +525,7 @@ class UserCommandTests(unittest.TestCase): ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) msg = "Sorry, but you may only use this command within <#50>." - with self.assertRaises(InChannelCheckFailure, msg=msg): + with self.assertRaises(InWhitelistedContextCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) |