aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/error_handler.py6
-rw-r--r--bot/cogs/information.py4
-rw-r--r--bot/cogs/moderation/management.py22
-rw-r--r--bot/cogs/verification.py4
-rw-r--r--bot/constants.py5
-rw-r--r--bot/decorators.py55
-rw-r--r--bot/utils/checks.py94
-rw-r--r--tests/bot/cogs/test_information.py3
-rw-r--r--tests/bot/test_decorators.py4
-rw-r--r--tests/bot/utils/test_checks.py52
10 files changed, 161 insertions, 88 deletions
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py
index 23d1eed82..5de961116 100644
--- a/bot/cogs/error_handler.py
+++ b/bot/cogs/error_handler.py
@@ -9,7 +9,7 @@ from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels
from bot.converters import TagNameConverter
-from bot.decorators import InWhitelistCheckFailure
+from bot.utils.checks import InWhitelistCheckFailure
log = logging.getLogger(__name__)
@@ -166,7 +166,7 @@ class ErrorHandler(Cog):
await prepared_help_command
self.bot.stats.incr("errors.missing_required_argument")
elif isinstance(e, errors.TooManyArguments):
- await ctx.send(f"Too many arguments provided.")
+ await ctx.send("Too many arguments provided.")
await prepared_help_command
self.bot.stats.incr("errors.too_many_arguments")
elif isinstance(e, errors.BadArgument):
@@ -206,7 +206,7 @@ class ErrorHandler(Cog):
if isinstance(e, bot_missing_errors):
ctx.bot.stats.incr("errors.bot_permission_error")
await ctx.send(
- f"Sorry, it looks like I don't have the permissions or roles I need to do that."
+ "Sorry, it looks like I don't have the permissions or roles I need to do that."
)
elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)):
ctx.bot.stats.incr("errors.wrong_channel_or_dm_error")
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index ef2f308ca..f0eb3a1ea 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -12,9 +12,9 @@ from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
-from bot.decorators import InWhitelistCheckFailure, in_whitelist, with_role
+from bot.decorators import in_whitelist, with_role
from bot.pagination import LinePaginator
-from bot.utils.checks import cooldown_with_role_bypass, with_role_check
+from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check
from bot.utils.time import time_since
log = logging.getLogger(__name__)
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index edfdfd9e2..c39c7f3bc 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -12,7 +12,7 @@ from bot.bot import Bot
from bot.converters import Expiry, InfractionSearchQuery, allowed_strings, proxy_user
from bot.pagination import LinePaginator
from bot.utils import time
-from bot.utils.checks import in_channel_check, with_role_check
+from bot.utils.checks import in_whitelist_check, with_role_check
from . import utils
from .infractions import Infractions
from .modlog import ModLog
@@ -49,8 +49,8 @@ class ModManagement(commands.Cog):
async def infraction_edit(
self,
ctx: Context,
- infraction_id: t.Union[int, allowed_strings("l", "last", "recent")],
- duration: t.Union[Expiry, allowed_strings("p", "permanent"), None],
+ infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], # noqa: F821
+ duration: t.Union[Expiry, allowed_strings("p", "permanent"), None], # noqa: F821
*,
reason: str = None
) -> None:
@@ -83,14 +83,14 @@ class ModManagement(commands.Cog):
"actor__id": ctx.author.id,
"ordering": "-inserted_at"
}
- infractions = await self.bot.api_client.get(f"bot/infractions", params=params)
+ infractions = await self.bot.api_client.get("bot/infractions", params=params)
if infractions:
old_infraction = infractions[0]
infraction_id = old_infraction["id"]
else:
await ctx.send(
- f":x: Couldn't find most recent infraction; you have never given an infraction."
+ ":x: Couldn't find most recent infraction; you have never given an infraction."
)
return
else:
@@ -224,7 +224,7 @@ class ModManagement(commands.Cog):
) -> None:
"""Send a paginated embed of infractions for the specified user."""
if not infractions:
- await ctx.send(f":warning: No infractions could be found for that query.")
+ await ctx.send(":warning: No infractions could be found for that query.")
return
lines = tuple(
@@ -283,10 +283,16 @@ class ModManagement(commands.Cog):
# This cannot be static (must have a __func__ attribute).
def cog_check(self, ctx: Context) -> bool:
- """Only allow moderators from moderator channels to invoke the commands in this cog."""
+ """Only allow moderators inside moderator channels to invoke the commands in this cog."""
checks = [
with_role_check(ctx, *constants.MODERATION_ROLES),
- in_channel_check(ctx, *constants.MODERATION_CHANNELS)
+ in_whitelist_check(
+ ctx,
+ channels=constants.MODERATION_CHANNELS,
+ categories=[constants.Categories.modmail],
+ redirect=None,
+ fail_silently=True,
+ )
]
return all(checks)
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index 77e8b5706..99be3cdaa 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -9,8 +9,8 @@ from discord.ext.commands import Cog, Context, command
from bot import constants
from bot.bot import Bot
from bot.cogs.moderation import ModLog
-from bot.decorators import InWhitelistCheckFailure, in_whitelist, without_role
-from bot.utils.checks import without_role_check
+from bot.decorators import in_whitelist, without_role
+from bot.utils.checks import InWhitelistCheckFailure, without_role_check
log = logging.getLogger(__name__)
diff --git a/bot/constants.py b/bot/constants.py
index 39de2ee41..2ce5355be 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -612,13 +612,10 @@ PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir))
MODERATION_ROLES = Guild.moderation_roles
STAFF_ROLES = Guild.staff_roles
-# Roles combinations
+# Channel combinations
STAFF_CHANNELS = Guild.staff_channels
-
-# Default Channel combinations
MODERATION_CHANNELS = Guild.moderation_channels
-
# Bot replies
NEGATIVE_REPLIES = [
"Noooooo!!",
diff --git a/bot/decorators.py b/bot/decorators.py
index 306f0830c..500197c89 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -9,37 +9,21 @@ from weakref import WeakValueDictionary
from discord import Colour, Embed, Member
from discord.errors import NotFound
from discord.ext import commands
-from discord.ext.commands import CheckFailure, Cog, Context
+from discord.ext.commands import Cog, Context
from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
-from bot.utils.checks import with_role_check, without_role_check
+from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check
log = logging.getLogger(__name__)
-class InWhitelistCheckFailure(CheckFailure):
- """Raised when the `in_whitelist` check fails."""
-
- def __init__(self, redirect_channel: Optional[int]) -> None:
- self.redirect_channel = redirect_channel
-
- if redirect_channel:
- redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
- else:
- redirect_message = ""
-
- error_message = f"You are not allowed to use that command{redirect_message}."
-
- super().__init__(error_message)
-
-
def in_whitelist(
*,
channels: Container[int] = (),
categories: Container[int] = (),
roles: Container[int] = (),
redirect: Optional[int] = Channels.bot_commands,
-
+ fail_silently: bool = False,
) -> Callable:
"""
Check if a command was issued in a whitelisted context.
@@ -54,36 +38,9 @@ def in_whitelist(
redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
told that they're not allowed to use this particular command (if `None` was passed).
"""
- if redirect and redirect not in channels:
- # It does not make sense for the channel whitelist to not contain the redirection
- # channel (if applicable). That's why we add the redirection channel to the `channels`
- # container if it's not already in it. As we allow any container type to be passed,
- # we first create a tuple in order to safely add the redirection channel.
- #
- # Note: It's possible for the redirect channel to be in a whitelisted category, but
- # there's no easy way to check that and as a channel can easily be moved in and out of
- # categories, it's probably not wise to rely on its category in any case.
- channels = tuple(channels) + (redirect,)
-
def predicate(ctx: Context) -> bool:
- """Check if a command was issued in a whitelisted context."""
- if channels and ctx.channel.id in channels:
- log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
- return True
-
- # Only check the category id if we have a category whitelist and the channel has a `category_id`
- if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
- log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
- return True
-
- # Only check the roles whitelist if we have one and ensure the author's roles attribute returns
- # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
- if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
- log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
- return True
-
- log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")
- raise InWhitelistCheckFailure(redirect)
+ """Check if command was issued in a whitelisted context."""
+ return in_whitelist_check(ctx, channels, categories, roles, redirect, fail_silently)
return commands.check(predicate)
@@ -121,7 +78,7 @@ def locked() -> Callable:
embed = Embed()
embed.colour = Colour.red()
- log.debug(f"User tried to invoke a locked command.")
+ log.debug("User tried to invoke a locked command.")
embed.description = (
"You're already using this command. Please wait until it is done before you use it again."
)
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index db56c347c..f0ef36302 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -1,12 +1,94 @@
import datetime
import logging
-from typing import Callable, Iterable
+from typing import Callable, Container, Iterable, Optional
-from discord.ext.commands import BucketType, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping
+from discord.ext.commands import (
+ BucketType,
+ CheckFailure,
+ Cog,
+ Command,
+ CommandOnCooldown,
+ Context,
+ Cooldown,
+ CooldownMapping,
+)
+
+from bot import constants
log = logging.getLogger(__name__)
+class InWhitelistCheckFailure(CheckFailure):
+ """Raised when the `in_whitelist` check fails."""
+
+ def __init__(self, redirect_channel: Optional[int]) -> None:
+ self.redirect_channel = redirect_channel
+
+ if redirect_channel:
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ redirect_message = ""
+
+ error_message = f"You are not allowed to use that command{redirect_message}."
+
+ super().__init__(error_message)
+
+
+def in_whitelist_check(
+ ctx: Context,
+ channels: Container[int] = (),
+ categories: Container[int] = (),
+ roles: Container[int] = (),
+ redirect: Optional[int] = constants.Channels.bot_commands,
+ fail_silently: bool = False,
+) -> bool:
+ """
+ Check if a command was issued in a whitelisted context.
+
+ The whitelists that can be provided are:
+
+ - `channels`: a container with channel ids for whitelisted channels
+ - `categories`: a container with category ids for whitelisted categories
+ - `roles`: a container with with role ids for whitelisted roles
+
+ If the command was invoked in a context that was not whitelisted, the member is either
+ redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
+ told that they're not allowed to use this particular command (if `None` was passed).
+ """
+ if redirect and redirect not in channels:
+ # It does not make sense for the channel whitelist to not contain the redirection
+ # channel (if applicable). That's why we add the redirection channel to the `channels`
+ # container if it's not already in it. As we allow any container type to be passed,
+ # we first create a tuple in order to safely add the redirection channel.
+ #
+ # Note: It's possible for the redirect channel to be in a whitelisted category, but
+ # there's no easy way to check that and as a channel can easily be moved in and out of
+ # categories, it's probably not wise to rely on its category in any case.
+ channels = tuple(channels) + (redirect,)
+
+ if channels and ctx.channel.id in channels:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
+ return True
+
+ # Only check the category id if we have a category whitelist and the channel has a `category_id`
+ if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
+ return True
+
+ # Only check the roles whitelist if we have one and ensure the author's roles attribute returns
+ # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
+ if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
+ return True
+
+ log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")
+
+ # Some commands are secret, and should produce no feedback at all.
+ if not fail_silently:
+ raise InWhitelistCheckFailure(redirect)
+ return False
+
+
def with_role_check(ctx: Context, *role_ids: int) -> bool:
"""Returns True if the user has any one of the roles in role_ids."""
if not ctx.guild: # Return False in a DM
@@ -38,14 +120,6 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:
return check
-def in_channel_check(ctx: Context, *channel_ids: int) -> bool:
- """Checks if the command was executed inside the list of specified channels."""
- check = ctx.channel.id in channel_ids
- log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the in_channel check was {check}.")
- return check
-
-
def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
bypass_roles: Iterable[int]) -> Callable:
"""
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index b5f928dd6..aca6b594f 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -7,10 +7,9 @@ import discord
from bot import constants
from bot.cogs import information
-from bot.decorators import InWhitelistCheckFailure
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
COG_PATH = "bot.cogs.information.Information"
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
index a17dd3e16..3d450caa0 100644
--- a/tests/bot/test_decorators.py
+++ b/tests/bot/test_decorators.py
@@ -3,10 +3,10 @@ import unittest
import unittest.mock
from bot import constants
-from bot.decorators import InWhitelistCheckFailure, in_whitelist
+from bot.decorators import in_whitelist
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description"))
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index 9610771e5..de72e5748 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -1,6 +1,8 @@
import unittest
+from unittest.mock import MagicMock
from bot.utils import checks
+from bot.utils.checks import InWhitelistCheckFailure
from tests.helpers import MockContext, MockRole
@@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):
self.ctx.author.roles.append(MockRole(id=role_id))
self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
- def test_in_channel_check_for_correct_channel(self):
- self.ctx.channel.id = 42
- self.assertTrue(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_correct_channel(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list."""
+ channel_id = 3
+ self.ctx.channel.id = channel_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id]))
- def test_in_channel_check_for_incorrect_channel(self):
- self.ctx.channel.id = 42 + 10
- self.assertFalse(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_incorrect_channel(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match."""
+ self.ctx.channel.id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, [4])
+
+ def test_in_whitelist_check_correct_category(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list."""
+ category_id = 3
+ self.ctx.channel.category_id = category_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id]))
+
+ def test_in_whitelist_check_incorrect_category(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match."""
+ self.ctx.channel.category_id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, categories=[4])
+
+ def test_in_whitelist_check_correct_role(self):
+ """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6]))
+
+ def test_in_whitelist_check_incorrect_role(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, roles=[4])
+
+ def test_in_whitelist_check_fail_silently(self):
+ """`in_whitelist_check` test no exception raised if `fail_silently` is `True`"""
+ self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True))
+
+ def test_in_whitelist_check_complex(self):
+ """`in_whitelist_check` test with multiple parameters"""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.ctx.channel.category_id = 3
+ self.ctx.channel.id = 5
+ self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2]))