aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/error_handler.py6
-rw-r--r--bot/cogs/information.py6
-rw-r--r--bot/cogs/snekbox.py10
-rw-r--r--bot/cogs/utils.py4
-rw-r--r--bot/cogs/verification.py12
-rw-r--r--bot/decorators.py85
-rw-r--r--tests/bot/cogs/test_cogs.py4
-rw-r--r--tests/bot/cogs/test_information.py4
-rw-r--r--tests/bot/test_decorators.py147
-rw-r--r--tests/helpers.py23
10 files changed, 249 insertions, 52 deletions
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py
index dae283c6a..b2f4c59f6 100644
--- a/bot/cogs/error_handler.py
+++ b/bot/cogs/error_handler.py
@@ -9,7 +9,7 @@ from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels
from bot.converters import TagNameConverter
-from bot.decorators import InChannelCheckFailure
+from bot.decorators import InWhitelistCheckFailure
log = logging.getLogger(__name__)
@@ -202,7 +202,7 @@ class ErrorHandler(Cog):
* BotMissingRole
* BotMissingAnyRole
* NoPrivateMessage
- * InChannelCheckFailure
+ * InWhitelistCheckFailure
"""
bot_missing_errors = (
errors.BotMissingPermissions,
@@ -215,7 +215,7 @@ class ErrorHandler(Cog):
await ctx.send(
f"Sorry, it looks like I don't have the permissions or roles I need to do that."
)
- elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)):
+ elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)):
ctx.bot.stats.incr("errors.wrong_channel_or_dm_error")
await ctx.send(e)
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index 7921a4932..4eb36c340 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -12,7 +12,7 @@ from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
-from bot.decorators import InChannelCheckFailure, in_channel, with_role
+from bot.decorators import InWhitelistCheckFailure, in_whitelist, with_role
from bot.pagination import LinePaginator
from bot.utils.checks import cooldown_with_role_bypass, with_role_check
from bot.utils.time import time_since
@@ -152,7 +152,7 @@ class Information(Cog):
# Non-staff may only do this in #bot-commands
if not with_role_check(ctx, *constants.STAFF_ROLES):
if not ctx.channel.id == constants.Channels.bot_commands:
- raise InChannelCheckFailure(constants.Channels.bot_commands)
+ raise InWhitelistCheckFailure(constants.Channels.bot_commands)
embed = await self.create_user_embed(ctx, user)
@@ -331,7 +331,7 @@ class Information(Cog):
@cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES)
@group(invoke_without_command=True)
- @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES)
+ @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_ROLES)
async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None:
"""Shows information about the raw API response."""
# I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index 315383b12..8d4688114 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -12,8 +12,8 @@ from discord import HTTPException, Message, NotFound, Reaction, User
from discord.ext.commands import Cog, Context, command, guild_only
from bot.bot import Bot
-from bot.constants import Channels, Roles, URLs
-from bot.decorators import in_channel
+from bot.constants import Categories, Channels, Roles, URLs
+from bot.decorators import in_whitelist
from bot.utils.messages import wait_for_deletion
log = logging.getLogger(__name__)
@@ -38,6 +38,10 @@ RAW_CODE_REGEX = re.compile(
)
MAX_PASTE_LEN = 1000
+
+# `!eval` command whitelists
+EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric)
+EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use)
EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
SIGKILL = 9
@@ -265,7 +269,7 @@ class Snekbox(Cog):
@command(name="eval", aliases=("e",))
@guild_only()
- @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
+ @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES)
async def eval_command(self, ctx: Context, *, code: str = None) -> None:
"""
Run Python code and get the results.
diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py
index 3ed471bbf..8023eb962 100644
--- a/bot/cogs/utils.py
+++ b/bot/cogs/utils.py
@@ -13,7 +13,7 @@ from discord.ext.commands import BadArgument, Cog, Context, command
from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES
-from bot.decorators import in_channel, with_role
+from bot.decorators import in_whitelist, with_role
from bot.utils.time import humanize_delta
log = logging.getLogger(__name__)
@@ -118,7 +118,7 @@ class Utils(Cog):
await ctx.message.channel.send(embed=pep_embed)
@command()
- @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES)
+ @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES)
async def charinfo(self, ctx: Context, *, characters: str) -> None:
"""Shows you information on up to 25 unicode characters."""
match = re.match(r"<(a?):(\w+):(\d+)>", characters)
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index b0a493e68..388b7a338 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command
from bot import constants
from bot.bot import Bot
from bot.cogs.moderation import ModLog
-from bot.decorators import InChannelCheckFailure, in_channel, without_role
+from bot.decorators import InWhitelistCheckFailure, in_whitelist, without_role
from bot.utils.checks import without_role_check
log = logging.getLogger(__name__)
@@ -122,7 +122,7 @@ class Verification(Cog):
@command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True)
@without_role(constants.Roles.verified)
- @in_channel(constants.Channels.verification)
+ @in_whitelist(channels=(constants.Channels.verification,))
async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Accept our rules and gain access to the rest of the server."""
log.debug(f"{ctx.author} called !accept. Assigning the 'Developer' role.")
@@ -138,7 +138,7 @@ class Verification(Cog):
await ctx.message.delete()
@command(name='subscribe')
- @in_channel(constants.Channels.bot_commands)
+ @in_whitelist(channels=(constants.Channels.bot_commands,))
async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Subscribe to announcement notifications by assigning yourself the role."""
has_role = False
@@ -162,7 +162,7 @@ class Verification(Cog):
)
@command(name='unsubscribe')
- @in_channel(constants.Channels.bot_commands)
+ @in_whitelist(channels=(constants.Channels.bot_commands,))
async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Unsubscribe from announcement notifications by removing the role from yourself."""
has_role = False
@@ -187,8 +187,8 @@ class Verification(Cog):
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
- """Check for & ignore any InChannelCheckFailure."""
- if isinstance(error, InChannelCheckFailure):
+ """Check for & ignore any InWhitelistCheckFailure."""
+ if isinstance(error, InWhitelistCheckFailure):
error.handled = True
@staticmethod
diff --git a/bot/decorators.py b/bot/decorators.py
index 2d18eaa6a..2ee5879f2 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -3,7 +3,7 @@ import random
from asyncio import Lock, sleep
from contextlib import suppress
from functools import wraps
-from typing import Callable, Container, Union
+from typing import Callable, Container, Optional, Union
from weakref import WeakValueDictionary
from discord import Colour, Embed, Member
@@ -11,54 +11,79 @@ from discord.errors import NotFound
from discord.ext import commands
from discord.ext.commands import CheckFailure, Cog, Context
-from bot.constants import ERROR_REPLIES, RedirectOutput
+from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
from bot.utils.checks import with_role_check, without_role_check
log = logging.getLogger(__name__)
-class InChannelCheckFailure(CheckFailure):
- """Raised when a check fails for a message being sent in a whitelisted channel."""
+class InWhitelistCheckFailure(CheckFailure):
+ """Raised when the `in_whitelist` check fails."""
- def __init__(self, *channels: int):
- self.channels = channels
- channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)
+ def __init__(self, redirect_channel: Optional[int]) -> None:
+ self.redirect_channel = redirect_channel
- super().__init__(f"Sorry, but you may only use this command within {channels_str}.")
+ if redirect_channel:
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ redirect_message = ""
+ error_message = f"You are not allowed to use that command{redirect_message}."
+
+ super().__init__(error_message)
+
+
+def in_whitelist(
+ *,
+ channels: Container[int] = (),
+ categories: Container[int] = (),
+ roles: Container[int] = (),
+ redirect: Optional[int] = Channels.bot_commands,
-def in_channel(
- *channels: int,
- hidden_channels: Container[int] = None,
- bypass_roles: Container[int] = None
) -> Callable:
"""
- Checks that the message is in a whitelisted channel or optionally has a bypass role.
+ Check if a command was issued in a whitelisted context.
+
+ The whitelists that can be provided are:
- Hidden channels are channels which will not be displayed in the InChannelCheckFailure error
- message.
+ - `channels`: a container with channel ids for whitelisted channels
+ - `categories`: a container with category ids for whitelisted categories
+ - `roles`: a container with with role ids for whitelisted roles
+
+ If the command was invoked in a context that was not whitelisted, the member is either
+ redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
+ told that they're not allowed to use this particular command (if `None` was passed).
"""
- hidden_channels = hidden_channels or []
- bypass_roles = bypass_roles or []
+ if redirect and redirect not in channels:
+ # It does not make sense for the channel whitelist to not contain the redirection
+ # channel (if applicable). That's why we add the redirection channel to the `channels`
+ # container if it's not already in it. As we allow any container type to be passed,
+ # we first create a tuple in order to safely add the redirection channel.
+ #
+ # Note: It's possible for the redirect channel to be in a whitelisted category, but
+ # there's no easy way to check that and as a channel can easily be moved in and out of
+ # categories, it's probably not wise to rely on its category in any case.
+ channels = tuple(channels) + (redirect,)
def predicate(ctx: Context) -> bool:
- """In-channel checker predicate."""
- if ctx.channel.id in channels or ctx.channel.id in hidden_channels:
- log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The command was used in a whitelisted channel.")
+ """Check if a command was issued in a whitelisted context."""
+ if channels and ctx.channel.id in channels:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
return True
- if bypass_roles:
- if any(r.id in bypass_roles for r in ctx.author.roles):
- log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The command was not used in a whitelisted channel, "
- f"but the author had a role to bypass the in_channel check.")
- return True
+ # Only check the category id if we have a category whitelist and the channel has a `category_id`
+ if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
+ return True
- log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The in_channel check failed.")
+ # Only check the roles whitelist if we have one and ensure the author's roles attribute returns
+ # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
+ if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
+ return True
- raise InChannelCheckFailure(*channels)
+ log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")
+ raise InWhitelistCheckFailure(redirect)
return commands.check(predicate)
diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py
index 39f6492cb..fdda59a8f 100644
--- a/tests/bot/cogs/test_cogs.py
+++ b/tests/bot/cogs/test_cogs.py
@@ -31,7 +31,7 @@ class CommandNameTests(unittest.TestCase):
def walk_modules() -> t.Iterator[ModuleType]:
"""Yield imported modules from the bot.cogs subpackage."""
def on_error(name: str) -> t.NoReturn:
- raise ImportError(name=name)
+ raise ImportError(name=name) # pragma: no cover
# The mock prevents asyncio.get_event_loop() from being called.
with mock.patch("discord.ext.tasks.loop"):
@@ -71,7 +71,7 @@ class CommandNameTests(unittest.TestCase):
for name in self.get_qualified_names(cmd):
with self.subTest(cmd=func_name, name=name):
- if name in all_names:
+ if name in all_names: # pragma: no cover
conflicts = ", ".join(all_names.get(name, ""))
self.fail(
f"Name '{name}' of the command {func_name} conflicts with {conflicts}."
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 3c26374f5..6dace1080 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -7,7 +7,7 @@ import discord
from bot import constants
from bot.cogs import information
-from bot.decorators import InChannelCheckFailure
+from bot.decorators import InWhitelistCheckFailure
from tests import helpers
@@ -525,7 +525,7 @@ class UserCommandTests(unittest.TestCase):
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))
msg = "Sorry, but you may only use this command within <#50>."
- with self.assertRaises(InChannelCheckFailure, msg=msg):
+ with self.assertRaises(InWhitelistCheckFailure, msg=msg):
asyncio.run(self.cog.user_info.callback(self.cog, ctx))
@unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
new file mode 100644
index 000000000..a17dd3e16
--- /dev/null
+++ b/tests/bot/test_decorators.py
@@ -0,0 +1,147 @@
+import collections
+import unittest
+import unittest.mock
+
+from bot import constants
+from bot.decorators import InWhitelistCheckFailure, in_whitelist
+from tests import helpers
+
+
+InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description"))
+
+
+class InWhitelistTests(unittest.TestCase):
+ """Tests for the `in_whitelist` check."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Set up helpers that only need to be defined once."""
+ cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456)
+ cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654)
+ cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666)
+ cls.dm_channel = helpers.MockDMChannel()
+
+ cls.non_staff_member = helpers.MockMember()
+ cls.staff_role = helpers.MockRole(id=121212)
+ cls.staff_member = helpers.MockMember(roles=(cls.staff_role,))
+
+ cls.channels = (cls.bot_commands.id,)
+ cls.categories = (cls.help_channel.category_id,)
+ cls.roles = (cls.staff_role.id,)
+
+ def test_predicate_returns_true_for_whitelisted_context(self):
+ """The predicate should return `True` if a whitelisted context was passed to it."""
+ test_cases = (
+ InWhitelistTestCase(
+ kwargs={"channels": self.channels},
+ ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member),
+ description="In whitelisted channels by members without whitelisted roles",
+ ),
+ InWhitelistTestCase(
+ kwargs={"redirect": self.bot_commands.id},
+ ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member),
+ description="`redirect` should be implicitly added to `channels`",
+ ),
+ InWhitelistTestCase(
+ kwargs={"categories": self.categories},
+ ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member),
+ description="Whitelisted category without whitelisted role",
+ ),
+ InWhitelistTestCase(
+ kwargs={"roles": self.roles},
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member),
+ description="Whitelisted role outside of whitelisted channel/category"
+ ),
+ InWhitelistTestCase(
+ kwargs={
+ "channels": self.channels,
+ "categories": self.categories,
+ "roles": self.roles,
+ "redirect": self.bot_commands,
+ },
+ ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member),
+ description="Case with all whitelist kwargs used",
+ ),
+ )
+
+ for test_case in test_cases:
+ # patch `commands.check` with a no-op lambda that just returns the predicate passed to it
+ # so we can test the predicate that was generated from the specified kwargs.
+ with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ predicate = in_whitelist(**test_case.kwargs)
+
+ with self.subTest(test_description=test_case.description):
+ self.assertTrue(predicate(test_case.ctx))
+
+ def test_predicate_raises_exception_for_non_whitelisted_context(self):
+ """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context."""
+ test_cases = (
+ # Failing check with explicit `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": self.bot_commands.id,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check with an explicit redirect channel",
+ ),
+
+ # Failing check with implicit `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check with an implicit redirect channel",
+ ),
+
+ # Failing check without `redirect`
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": None,
+ },
+ ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member),
+ description="Failing check without a redirect channel",
+ ),
+
+ # Command issued in DM channel
+ InWhitelistTestCase(
+ kwargs={
+ "categories": self.categories,
+ "channels": self.channels,
+ "roles": self.roles,
+ "redirect": None,
+ },
+ ctx=helpers.MockContext(channel=self.dm_channel, author=self.dm_channel.me),
+ description="Commands issued in DM channel should be rejected",
+ ),
+ )
+
+ for test_case in test_cases:
+ if "redirect" not in test_case.kwargs or test_case.kwargs["redirect"] is not None:
+ # There are two cases in which we have a redirect channel:
+ # 1. No redirect channel was passed; the default value of `bot_commands` is used
+ # 2. An explicit `redirect` is set that is "not None"
+ redirect_channel = test_case.kwargs.get("redirect", constants.Channels.bot_commands)
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ # If an explicit `None` was passed for `redirect`, there is no redirect channel
+ redirect_message = ""
+
+ exception_message = f"You are not allowed to use that command{redirect_message}."
+
+ # patch `commands.check` with a no-op lambda that just returns the predicate passed to it
+ # so we can test the predicate that was generated from the specified kwargs.
+ with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate):
+ predicate = in_whitelist(**test_case.kwargs)
+
+ with self.subTest(test_description=test_case.description):
+ with self.assertRaisesRegex(InWhitelistCheckFailure, exception_message):
+ predicate(test_case.ctx)
diff --git a/tests/helpers.py b/tests/helpers.py
index 8e13f0f28..2b79a6c2a 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -315,7 +315,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
spec_set = channel_instance
- def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
+ def __init__(self, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
super().__init__(**collections.ChainMap(kwargs, default_kwargs))
@@ -323,6 +323,27 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
self.mention = f"#{self.name}"
+# Create data for the DMChannel instance
+state = unittest.mock.MagicMock()
+me = unittest.mock.MagicMock()
+dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]}
+dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data)
+
+
+class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
+ """
+ A MagicMock subclass to mock TextChannel objects.
+
+ Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ spec_set = dm_channel_instance
+
+ def __init__(self, **kwargs) -> None:
+ default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()}
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
+
+
# Create a Message instance to get a realistic MagicMock of `discord.Message`
message_data = {
'id': 1,