aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/decorators.py29
-rw-r--r--bot/exts/filters/filter_lists.py7
-rw-r--r--bot/exts/fun/off_topic_names.py15
-rw-r--r--bot/exts/help_channels.py7
-rw-r--r--bot/exts/info/doc.py7
-rw-r--r--bot/exts/info/information.py15
-rw-r--r--bot/exts/info/reddit.py5
-rw-r--r--bot/exts/moderation/defcon.py13
-rw-r--r--bot/exts/moderation/dm_relay.py6
-rw-r--r--bot/exts/moderation/infraction/infractions.py5
-rw-r--r--bot/exts/moderation/infraction/management.py6
-rw-r--r--bot/exts/moderation/infraction/superstarify.py7
-rw-r--r--bot/exts/moderation/silence.py5
-rw-r--r--bot/exts/moderation/slowmode.py7
-rw-r--r--bot/exts/moderation/verification.py15
-rw-r--r--bot/exts/moderation/watchchannels/bigbrother.py13
-rw-r--r--bot/exts/moderation/watchchannels/talentpool.py19
-rw-r--r--bot/exts/utils/bot.py11
-rw-r--r--bot/exts/utils/clean.py17
-rw-r--r--bot/exts/utils/eval.py7
-rw-r--r--bot/exts/utils/extensions.py5
-rw-r--r--bot/exts/utils/jams.py3
-rw-r--r--bot/exts/utils/reminders.py10
-rw-r--r--bot/exts/utils/utils.py6
-rw-r--r--bot/utils/checks.py49
-rw-r--r--tests/bot/exts/moderation/test_silence.py10
-rw-r--r--tests/bot/exts/moderation/test_slowmode.py10
-rw-r--r--tests/bot/utils/test_checks.py44
28 files changed, 173 insertions, 180 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index 500197c89..2518124da 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -6,13 +6,12 @@ from functools import wraps
from typing import Callable, Container, Optional, Union
from weakref import WeakValueDictionary
-from discord import Colour, Embed, Member
-from discord.errors import NotFound
+from discord import Colour, Embed, Member, NotFound
from discord.ext import commands
from discord.ext.commands import Cog, Context
from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
-from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check
+from bot.utils.checks import in_whitelist_check
log = logging.getLogger(__name__)
@@ -45,18 +44,22 @@ def in_whitelist(
return commands.check(predicate)
-def with_role(*role_ids: int) -> Callable:
- """Returns True if the user has any one of the roles in role_ids."""
- async def predicate(ctx: Context) -> bool:
- """With role checker predicate."""
- return with_role_check(ctx, *role_ids)
- return commands.check(predicate)
-
+def has_no_roles(*roles: Union[str, int]) -> Callable:
+ """
+ Returns True if the user does not have any of the roles specified.
-def without_role(*role_ids: int) -> Callable:
- """Returns True if the user does not have any of the roles in role_ids."""
+ `roles` are the names or IDs of the disallowed roles.
+ """
async def predicate(ctx: Context) -> bool:
- return without_role_check(ctx, *role_ids)
+ try:
+ await commands.has_any_role(*roles).predicate(ctx)
+ except commands.MissingAnyRole:
+ return True
+ else:
+ # This error is never shown to users, so don't bother trying to make it too pretty.
+ roles_ = ", ".join(f"'{item}'" for item in roles)
+ raise commands.CheckFailure(f"You have at least one of the disallowed roles: {roles_}")
+
return commands.check(predicate)
diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py
index c15adc461..232c1e48b 100644
--- a/bot/exts/filters/filter_lists.py
+++ b/bot/exts/filters/filter_lists.py
@@ -2,14 +2,13 @@ import logging
from typing import Optional
from discord import Colour, Embed
-from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group
+from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role
from bot import constants
from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.converters import ValidDiscordServerInvite, ValidFilterListType
from bot.pagination import LinePaginator
-from bot.utils.checks import with_role_check
log = logging.getLogger(__name__)
@@ -263,9 +262,9 @@ class FilterLists(Cog):
"""Syncs both allowlists and denylists with the API."""
await self._sync_data(ctx)
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
- return with_role_check(ctx, *constants.MODERATION_ROLES)
+ return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx)
def setup(bot: Bot) -> None:
diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py
index ce95450e0..b9d235fa2 100644
--- a/bot/exts/fun/off_topic_names.py
+++ b/bot/exts/fun/off_topic_names.py
@@ -4,13 +4,12 @@ import logging
from datetime import datetime, timedelta
from discord import Colour, Embed
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES
from bot.converters import OffTopicName
-from bot.decorators import with_role
from bot.pagination import LinePaginator
CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2)
@@ -67,13 +66,13 @@ class OffTopicNames(Cog):
self.updater_task = self.bot.loop.create_task(coro)
@group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True)
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def otname_group(self, ctx: Context) -> None:
"""Add or list items from the off-topic channel name rotation."""
await ctx.send_help(ctx.command)
@otname_group.command(name='add', aliases=('a',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def add_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""
Adds a new off-topic name to the rotation.
@@ -96,7 +95,7 @@ class OffTopicNames(Cog):
await self._add_name(ctx, name)
@otname_group.command(name='forceadd', aliases=('fa',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""Forcefully adds a new off-topic name to the rotation."""
await self._add_name(ctx, name)
@@ -109,7 +108,7 @@ class OffTopicNames(Cog):
await ctx.send(f":ok_hand: Added `{name}` to the names list.")
@otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd'))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""Removes a off-topic name from the rotation."""
await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}')
@@ -118,7 +117,7 @@ class OffTopicNames(Cog):
await ctx.send(f":ok_hand: Removed `{name}` from the names list.")
@otname_group.command(name='list', aliases=('l',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def list_command(self, ctx: Context) -> None:
"""
Lists all currently known off-topic channel names in a paginator.
@@ -138,7 +137,7 @@ class OffTopicNames(Cog):
await ctx.send(embed=embed)
@otname_group.command(name='search', aliases=('s',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def search_command(self, ctx: Context, *, query: OffTopicName) -> None:
"""Search for an off-topic name."""
result = await self.bot.api_client.get('bot/off-topic-channel-names')
diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py
index 0f9cac89e..17142071f 100644
--- a/bot/exts/help_channels.py
+++ b/bot/exts/help_channels.py
@@ -14,7 +14,6 @@ from discord.ext import commands
from bot import constants
from bot.bot import Bot
from bot.utils import RedisCache
-from bot.utils.checks import with_role_check
from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
@@ -196,12 +195,12 @@ class HelpChannels(commands.Cog):
return True
log.trace(f"{ctx.author} is not the help channel claimant, checking roles.")
- role_check = with_role_check(ctx, *constants.HelpChannels.cmd_whitelist)
+ has_role = await commands.has_any_role(*constants.HelpChannels.cmd_whitelist).predicate(ctx)
- if role_check:
+ if has_role:
self.bot.stats.incr("help.dormant_invoke.staff")
- return role_check
+ return has_role
@commands.command(name="close", aliases=["dormant", "solved"], enabled=False)
async def close_command(self, ctx: commands.Context) -> None:
diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py
index 30c793c75..e50b9b32b 100644
--- a/bot/exts/info/doc.py
+++ b/bot/exts/info/doc.py
@@ -21,7 +21,6 @@ from urllib3.exceptions import ProtocolError
from bot.bot import Bot
from bot.constants import MODERATION_ROLES, RedirectOutput
from bot.converters import ValidPythonIdentifier, ValidURL
-from bot.decorators import with_role
from bot.pagination import LinePaginator
from bot.utils.messages import wait_for_deletion
@@ -396,7 +395,7 @@ class Doc(commands.Cog):
await wait_for_deletion(msg, (ctx.author.id,), client=self.bot)
@docs_group.command(name='set', aliases=('s',))
- @with_role(*MODERATION_ROLES)
+ @commands.has_any_role(*MODERATION_ROLES)
async def set_command(
self, ctx: commands.Context, package_name: ValidPythonIdentifier,
base_url: ValidURL, inventory_url: InventoryURL
@@ -433,7 +432,7 @@ class Doc(commands.Cog):
await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.")
@docs_group.command(name='delete', aliases=('remove', 'rm', 'd'))
- @with_role(*MODERATION_ROLES)
+ @commands.has_any_role(*MODERATION_ROLES)
async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None:
"""
Removes the specified package from the database.
@@ -450,7 +449,7 @@ class Doc(commands.Cog):
await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.")
@docs_group.command(name="refresh", aliases=("rfsh", "r"))
- @with_role(*MODERATION_ROLES)
+ @commands.has_any_role(*MODERATION_ROLES)
async def refresh_command(self, ctx: commands.Context) -> None:
"""Refresh inventories and send differences to channel."""
old_inventories = set(self.base_urls)
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 55ecb2836..581b3a227 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -8,18 +8,19 @@ from typing import Any, Mapping, Optional, Tuple, Union
from discord import ChannelType, Colour, CustomActivity, Embed, Guild, Member, Message, Role, Status, utils
from discord.abc import GuildChannel
-from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group
+from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role
from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
-from bot.decorators import in_whitelist, with_role
+from bot.decorators import in_whitelist
from bot.pagination import LinePaginator
-from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check
+from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, has_no_roles_check
from bot.utils.time import time_since
log = logging.getLogger(__name__)
+
STATUS_EMOTES = {
Status.offline: constants.Emojis.status_offline,
Status.dnd: constants.Emojis.status_dnd,
@@ -76,7 +77,7 @@ class Information(Cog):
channel_type_list = sorted(channel_type_list)
return "\n".join(channel_type_list)
- @with_role(*constants.MODERATION_ROLES)
+ @has_any_role(*constants.MODERATION_ROLES)
@command(name="roles")
async def roles_info(self, ctx: Context) -> None:
"""Returns a list of all roles and their corresponding IDs."""
@@ -96,7 +97,7 @@ class Information(Cog):
await LinePaginator.paginate(role_list, ctx, embed, empty=False)
- @with_role(*constants.MODERATION_ROLES)
+ @has_any_role(*constants.MODERATION_ROLES)
@command(name="role")
async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None:
"""
@@ -197,12 +198,12 @@ class Information(Cog):
user = ctx.author
# Do a role check if this is being executed on someone other than the caller
- elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES):
+ elif user != ctx.author and await has_no_roles_check(ctx, *constants.MODERATION_ROLES):
await ctx.send("You may not use this command on users other than yourself.")
return
# Non-staff may only do this in #bot-commands
- if not with_role_check(ctx, *constants.STAFF_ROLES):
+ if await has_no_roles_check(ctx, *constants.STAFF_ROLES):
if not ctx.channel.id == constants.Channels.bot_commands:
raise InWhitelistCheckFailure(constants.Channels.bot_commands)
diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py
index 5d9e2c20b..635162308 100644
--- a/bot/exts/info/reddit.py
+++ b/bot/exts/info/reddit.py
@@ -8,14 +8,13 @@ from typing import List
from aiohttp import BasicAuth, ClientError
from discord import Colour, Embed, TextChannel
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from discord.ext.tasks import loop
from discord.utils import escape_markdown
from bot.bot import Bot
from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks
from bot.converters import Subreddit
-from bot.decorators import with_role
from bot.pagination import LinePaginator
from bot.utils.messages import sub_clyde
@@ -282,7 +281,7 @@ class Reddit(Cog):
await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed)
- @with_role(*STAFF_ROLES)
+ @has_any_role(*STAFF_ROLES)
@reddit_group.command(name="subreddits", aliases=("subs",))
async def subreddits_command(self, ctx: Context) -> None:
"""Send a paginated embed of all the subreddits we're relaying."""
diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py
index 6e4008777..3bf462877 100644
--- a/bot/exts/moderation/defcon.py
+++ b/bot/exts/moderation/defcon.py
@@ -6,11 +6,10 @@ from datetime import datetime, timedelta
from enum import Enum
from discord import Colour, Embed, Member
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.bot import Bot
from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_ROLES, Roles
-from bot.decorators import with_role
from bot.exts.moderation.modlog import ModLog
log = logging.getLogger(__name__)
@@ -119,7 +118,7 @@ class Defcon(Cog):
)
@group(name='defcon', aliases=('dc',), invoke_without_command=True)
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def defcon_group(self, ctx: Context) -> None:
"""Check the DEFCON status or run a subcommand."""
await ctx.send_help(ctx.command)
@@ -163,7 +162,7 @@ class Defcon(Cog):
self.bot.stats.gauge("defcon.threshold", days)
@defcon_group.command(name='enable', aliases=('on', 'e'), root_aliases=("defon",))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def enable_command(self, ctx: Context) -> None:
"""
Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!
@@ -176,7 +175,7 @@ class Defcon(Cog):
await self.update_channel_topic()
@defcon_group.command(name='disable', aliases=('off', 'd'), root_aliases=("defoff",))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def disable_command(self, ctx: Context) -> None:
"""Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!"""
self.enabled = False
@@ -184,7 +183,7 @@ class Defcon(Cog):
await self.update_channel_topic()
@defcon_group.command(name='status', aliases=('s',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def status_command(self, ctx: Context) -> None:
"""Check the current status of DEFCON mode."""
embed = Embed(
@@ -196,7 +195,7 @@ class Defcon(Cog):
await ctx.send(embed=embed)
@defcon_group.command(name='days')
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def days_command(self, ctx: Context, days: int) -> None:
"""Set how old an account must be to join the server, in days, with DEFCON mode enabled."""
self.days = timedelta(days=days)
diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py
index 0d8f340b4..7a3fe49bb 100644
--- a/bot/exts/moderation/dm_relay.py
+++ b/bot/exts/moderation/dm_relay.py
@@ -10,7 +10,7 @@ from bot import constants
from bot.bot import Bot
from bot.converters import UserMentionOrID
from bot.utils import RedisCache
-from bot.utils.checks import in_whitelist_check, with_role_check
+from bot.utils.checks import in_whitelist_check
from bot.utils.messages import send_attachments
from bot.utils.webhooks import send_webhook
@@ -105,10 +105,10 @@ class DMRelay(Cog):
except discord.HTTPException:
log.exception("Failed to send an attachment to the webhook")
- def cog_check(self, ctx: commands.Context) -> bool:
+ async def cog_check(self, ctx: commands.Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
checks = [
- with_role_check(ctx, *constants.MODERATION_ROLES),
+ await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx),
in_whitelist_check(
ctx,
channels=[constants.Channels.dm_log],
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index 84ea47371..5fa62d3c4 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -15,7 +15,6 @@ from bot.decorators import respect_role_hierarchy
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._scheduler import InfractionScheduler
from bot.exts.moderation.infraction._utils import UserSnowflake
-from bot.utils.checks import with_role_check
log = logging.getLogger(__name__)
@@ -357,9 +356,9 @@ class Infractions(InfractionScheduler, commands.Cog):
# endregion
# This cannot be static (must have a __func__ attribute).
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
- return with_role_check(ctx, *constants.MODERATION_ROLES)
+ return await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx)
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py
index 5875abd26..15ee28537 100644
--- a/bot/exts/moderation/infraction/management.py
+++ b/bot/exts/moderation/infraction/management.py
@@ -15,7 +15,7 @@ from bot.exts.moderation.infraction.infractions import Infractions
from bot.exts.moderation.modlog import ModLog
from bot.pagination import LinePaginator
from bot.utils import time
-from bot.utils.checks import in_whitelist_check, with_role_check
+from bot.utils.checks import in_whitelist_check
log = logging.getLogger(__name__)
@@ -282,10 +282,10 @@ class ModManagement(commands.Cog):
# endregion
# This cannot be static (must have a __func__ attribute).
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators inside moderator channels to invoke the commands in this cog."""
checks = [
- with_role_check(ctx, *constants.MODERATION_ROLES),
+ await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx),
in_whitelist_check(
ctx,
channels=constants.MODERATION_CHANNELS,
diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py
index a4e78c4d3..29f41f2ab 100644
--- a/bot/exts/moderation/infraction/superstarify.py
+++ b/bot/exts/moderation/infraction/superstarify.py
@@ -6,14 +6,13 @@ import typing as t
from pathlib import Path
from discord import Colour, Embed, Member
-from discord.ext.commands import Cog, Context, command
+from discord.ext.commands import Cog, Context, command, has_any_role
from bot import constants
from bot.bot import Bot
from bot.converters import Expiry
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._scheduler import InfractionScheduler
-from bot.utils.checks import with_role_check
from bot.utils.time import format_infraction
log = logging.getLogger(__name__)
@@ -234,9 +233,9 @@ class Superstarify(InfractionScheduler, Cog):
return rng.choice(STAR_NAMES)
# This cannot be static (must have a __func__ attribute).
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
- return with_role_check(ctx, *constants.MODERATION_ROLES)
+ return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx)
def setup(bot: Bot) -> None:
diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py
index 4af87c724..ac0c1c85e 100644
--- a/bot/exts/moderation/silence.py
+++ b/bot/exts/moderation/silence.py
@@ -10,7 +10,6 @@ from discord.ext.commands import Context
from bot.bot import Bot
from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles
from bot.converters import HushDurationConverter
-from bot.utils.checks import with_role_check
from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
@@ -160,9 +159,9 @@ class Silence(commands.Cog):
asyncio.create_task(self._mod_alerts_channel.send(message))
# This cannot be static (must have a __func__ attribute).
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
- return with_role_check(ctx, *MODERATION_ROLES)
+ return await commands.has_any_role(*MODERATION_ROLES).predicate(ctx)
def setup(bot: Bot) -> None:
diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py
index 1d055afac..efd862aa5 100644
--- a/bot/exts/moderation/slowmode.py
+++ b/bot/exts/moderation/slowmode.py
@@ -4,12 +4,11 @@ from typing import Optional
from dateutil.relativedelta import relativedelta
from discord import TextChannel
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.bot import Bot
from bot.constants import Emojis, MODERATION_ROLES
from bot.converters import DurationDelta
-from bot.decorators import with_role_check
from bot.utils import time
log = logging.getLogger(__name__)
@@ -87,9 +86,9 @@ class Slowmode(Cog):
f'{Emojis.check_mark} The slowmode delay for {channel.mention} has been reset to 0 seconds.'
)
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
- return with_role_check(ctx, *MODERATION_ROLES)
+ return await has_any_role(*MODERATION_ROLES).predicate(ctx)
def setup(bot: Bot) -> None:
diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py
index 53fa0730b..8ec68ac1e 100644
--- a/bot/exts/moderation/verification.py
+++ b/bot/exts/moderation/verification.py
@@ -6,14 +6,14 @@ from datetime import datetime, timedelta
import discord
from discord.ext import tasks
-from discord.ext.commands import Cog, Context, command, group
+from discord.ext.commands import Cog, Context, command, group, has_any_role
from discord.utils import snowflake_time
from bot import constants
from bot.bot import Bot
-from bot.decorators import in_whitelist, with_role, without_role
+from bot.decorators import has_no_roles, in_whitelist
from bot.exts.moderation.modlog import ModLog
-from bot.utils.checks import InWhitelistCheckFailure, without_role_check
+from bot.utils.checks import InWhitelistCheckFailure, has_no_roles_check
from bot.utils.redis_cache import RedisCache
log = logging.getLogger(__name__)
@@ -568,7 +568,7 @@ class Verification(Cog):
# endregion
# region: task management commands
- @with_role(*constants.MODERATION_ROLES)
+ @has_any_role(*constants.MODERATION_ROLES)
@group(name="verification")
async def verification_group(self, ctx: Context) -> None:
"""Manage internal verification tasks."""
@@ -653,7 +653,7 @@ class Verification(Cog):
self.bot.stats.incr(f"verification.{category}")
@command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True)
- @without_role(constants.Roles.verified)
+ @has_no_roles(constants.Roles.verified)
@in_whitelist(channels=(constants.Channels.verification,))
async def accept_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Accept our rules and gain access to the rest of the server."""
@@ -736,9 +736,10 @@ class Verification(Cog):
error.handled = True
@staticmethod
- def bot_check(ctx: Context) -> bool:
+ async def bot_check(ctx: Context) -> bool:
"""Block any command within the verification channel that is not !accept."""
- if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES):
+ is_verification = ctx.channel.id == constants.Channels.verification
+ if is_verification and await has_no_roles_check(ctx, *constants.MODERATION_ROLES):
return ctx.command.name == "accept"
else:
return True
diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py
index d7127b5c4..3b44056d3 100644
--- a/bot/exts/moderation/watchchannels/bigbrother.py
+++ b/bot/exts/moderation/watchchannels/bigbrother.py
@@ -2,12 +2,11 @@ import logging
import textwrap
from collections import ChainMap
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES, Webhooks
from bot.converters import FetchedMember
-from bot.decorators import with_role
from bot.exts.moderation.infraction._utils import post_infraction
from bot.exts.moderation.watchchannels._watchchannel import WatchChannel
@@ -28,13 +27,13 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
)
@group(name='bigbrother', aliases=('bb',), invoke_without_command=True)
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def bigbrother_group(self, ctx: Context) -> None:
"""Monitors users by relaying their messages to the Big Brother watch channel."""
await ctx.send_help(ctx.command)
@bigbrother_group.command(name='watched', aliases=('all', 'list'))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def watched_command(
self, ctx: Context, oldest_first: bool = False, update_cache: bool = True
) -> None:
@@ -49,7 +48,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache)
@bigbrother_group.command(name='oldest')
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None:
"""
Shows Big Brother monitored users ordered by oldest watched.
@@ -60,7 +59,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache)
@bigbrother_group.command(name='watch', aliases=('w',), root_aliases=('watch',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:
"""
Relay messages sent by the given `user` to the `#big-brother` channel.
@@ -71,7 +70,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
await self.apply_watch(ctx, user, reason)
@bigbrother_group.command(name='unwatch', aliases=('uw',), root_aliases=('unwatch',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:
"""Stop relaying messages by the given `user`."""
await self.apply_unwatch(ctx, user, reason)
diff --git a/bot/exts/moderation/watchchannels/talentpool.py b/bot/exts/moderation/watchchannels/talentpool.py
index 3724e94e6..a77dbe156 100644
--- a/bot/exts/moderation/watchchannels/talentpool.py
+++ b/bot/exts/moderation/watchchannels/talentpool.py
@@ -4,13 +4,12 @@ from collections import ChainMap
from typing import Union
from discord import Color, Embed, Member, User
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks
from bot.converters import FetchedMember
-from bot.decorators import with_role
from bot.exts.moderation.watchchannels._watchchannel import WatchChannel
from bot.pagination import LinePaginator
from bot.utils import time
@@ -32,13 +31,13 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
)
@group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True)
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def nomination_group(self, ctx: Context) -> None:
"""Highlights the activity of helper nominees by relaying their messages to the talent pool channel."""
await ctx.send_help(ctx.command)
@nomination_group.command(name='watched', aliases=('all', 'list'), root_aliases=("nominees",))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def watched_command(
self, ctx: Context, oldest_first: bool = False, update_cache: bool = True
) -> None:
@@ -53,7 +52,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache)
@nomination_group.command(name='oldest')
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None:
"""
Shows talent pool monitored users ordered by oldest nomination.
@@ -64,7 +63,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache)
@nomination_group.command(name='watch', aliases=('w', 'add', 'a'), root_aliases=("nominate",))
- @with_role(*STAFF_ROLES)
+ @has_any_role(*STAFF_ROLES)
async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:
"""
Relay messages sent by the given `user` to the `#talent-pool` channel.
@@ -129,7 +128,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
await ctx.send(msg)
@nomination_group.command(name='history', aliases=('info', 'search'))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def history_command(self, ctx: Context, user: FetchedMember) -> None:
"""Shows the specified user's nomination history."""
result = await self.bot.api_client.get(
@@ -158,7 +157,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
)
@nomination_group.command(name='unwatch', aliases=('end', ), root_aliases=("unnominate",))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:
"""
Ends the active nomination of the specified user with the given reason.
@@ -171,13 +170,13 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
await ctx.send(":x: The specified user does not have an active nomination")
@nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True)
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def nomination_edit_group(self, ctx: Context) -> None:
"""Commands to edit nominations."""
await ctx.send_help(ctx.command)
@nomination_edit_group.command(name='reason')
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None:
"""
Edits the reason/unnominate reason for the nomination with the given `id` depending on the status.
diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py
index 66f340a99..7ed487d47 100644
--- a/bot/exts/utils/bot.py
+++ b/bot/exts/utils/bot.py
@@ -5,11 +5,10 @@ import time
from typing import Optional, Tuple
from discord import Embed, Message, RawMessageUpdateEvent, TextChannel
-from discord.ext.commands import Cog, Context, command, group
+from discord.ext.commands import Cog, Context, command, group, has_any_role
from bot.bot import Bot
from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs
-from bot.decorators import with_role
from bot.exts.filters.token_remover import TokenRemover
from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE
from bot.utils.messages import wait_for_deletion
@@ -39,13 +38,13 @@ class BotCog(Cog, name="Bot"):
self.codeblock_message_ids = {}
@group(invoke_without_command=True, name="bot", hidden=True)
- @with_role(Roles.verified)
+ @has_any_role(Roles.verified)
async def botinfo_group(self, ctx: Context) -> None:
"""Bot informational commands."""
await ctx.send_help(ctx.command)
@botinfo_group.command(name='about', aliases=('info',), hidden=True)
- @with_role(Roles.verified)
+ @has_any_role(Roles.verified)
async def about_command(self, ctx: Context) -> None:
"""Get information about the bot."""
embed = Embed(
@@ -63,7 +62,7 @@ class BotCog(Cog, name="Bot"):
await ctx.send(embed=embed)
@command(name='echo', aliases=('print',))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None:
"""Repeat the given message in either a specified channel or the current channel."""
if channel is None:
@@ -72,7 +71,7 @@ class BotCog(Cog, name="Bot"):
await channel.send(text)
@command(name='embed')
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None:
"""Send the input within an embed to either a specified channel or the current channel."""
embed = Embed(description=text)
diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py
index d9a7aafe1..236603dba 100644
--- a/bot/exts/utils/clean.py
+++ b/bot/exts/utils/clean.py
@@ -5,13 +5,12 @@ from typing import Iterable, Optional
from discord import Colour, Embed, Message, TextChannel, User
from discord.ext import commands
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.bot import Bot
from bot.constants import (
Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES
)
-from bot.decorators import with_role
from bot.exts.moderation.modlog import ModLog
log = logging.getLogger(__name__)
@@ -192,13 +191,13 @@ class Clean(Cog):
)
@group(invoke_without_command=True, name="clean", aliases=["purge"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_group(self, ctx: Context) -> None:
"""Commands for cleaning messages in channels."""
await ctx.send_help(ctx.command)
@clean_group.command(name="user", aliases=["users"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_user(
self,
ctx: Context,
@@ -210,7 +209,7 @@ class Clean(Cog):
await self._clean_messages(amount, ctx, user=user, channels=channels)
@clean_group.command(name="all", aliases=["everything"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_all(
self,
ctx: Context,
@@ -221,7 +220,7 @@ class Clean(Cog):
await self._clean_messages(amount, ctx, channels=channels)
@clean_group.command(name="bots", aliases=["bot"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_bots(
self,
ctx: Context,
@@ -232,7 +231,7 @@ class Clean(Cog):
await self._clean_messages(amount, ctx, bots_only=True, channels=channels)
@clean_group.command(name="regex", aliases=["word", "expression"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_regex(
self,
ctx: Context,
@@ -244,7 +243,7 @@ class Clean(Cog):
await self._clean_messages(amount, ctx, regex=regex, channels=channels)
@clean_group.command(name="message", aliases=["messages"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_message(self, ctx: Context, message: Message) -> None:
"""Delete all messages until certain message, stop cleaning after hitting the `message`."""
await self._clean_messages(
@@ -255,7 +254,7 @@ class Clean(Cog):
)
@clean_group.command(name="stop", aliases=["cancel", "abort"])
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def clean_cancel(self, ctx: Context) -> None:
"""If there is an ongoing cleaning process, attempt to immediately cancel it."""
self.cleaning = False
diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/eval.py
index 23e5998d8..6419b320e 100644
--- a/bot/exts/utils/eval.py
+++ b/bot/exts/utils/eval.py
@@ -9,11 +9,10 @@ from io import StringIO
from typing import Any, Optional, Tuple
import discord
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, group, has_any_role
from bot.bot import Bot
from bot.constants import Roles
-from bot.decorators import with_role
from bot.interpreter import Interpreter
from bot.utils import find_nth_occurrence, send_to_paste_service
@@ -199,14 +198,14 @@ async def func(): # (None,) -> Any
await ctx.send(f"```py\n{out}```", embed=embed)
@group(name='internal', aliases=('int',))
- @with_role(Roles.owners, Roles.admins)
+ @has_any_role(Roles.owners, Roles.admins)
async def internal_group(self, ctx: Context) -> None:
"""Internal commands. Top secret!"""
if not ctx.invoked_subcommand:
await ctx.send_help(ctx.command)
@internal_group.command(name='eval', aliases=('e',))
- @with_role(Roles.admins, Roles.owners)
+ @has_any_role(Roles.admins, Roles.owners)
async def eval(self, ctx: Context, *, code: str) -> None:
"""Run eval in a REPL-like format."""
code = code.strip("`")
diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py
index 123f356e8..418db0150 100644
--- a/bot/exts/utils/extensions.py
+++ b/bot/exts/utils/extensions.py
@@ -11,7 +11,6 @@ from bot import exts
from bot.bot import Bot
from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
from bot.pagination import LinePaginator
-from bot.utils.checks import with_role_check
from bot.utils.extensions import EXTENSIONS, unqualify
log = logging.getLogger(__name__)
@@ -248,9 +247,9 @@ class Extensions(commands.Cog):
return msg, error_msg
# This cannot be static (must have a __func__ attribute).
- def cog_check(self, ctx: Context) -> bool:
+ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators and core developers to invoke the commands in this cog."""
- return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers)
+ return await commands.has_any_role(*MODERATION_ROLES, Roles.core_developers).predicate(ctx)
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
diff --git a/bot/exts/utils/jams.py b/bot/exts/utils/jams.py
index b3102db2f..1c0988343 100644
--- a/bot/exts/utils/jams.py
+++ b/bot/exts/utils/jams.py
@@ -7,7 +7,6 @@ from more_itertools import unique_everseen
from bot.bot import Bot
from bot.constants import Roles
-from bot.decorators import with_role
log = logging.getLogger(__name__)
@@ -22,7 +21,7 @@ class CodeJams(commands.Cog):
self.bot = bot
@commands.command()
- @with_role(Roles.admins)
+ @commands.has_any_role(Roles.admins)
async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None:
"""
Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team.
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index 08bce2153..6806f2889 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -15,7 +15,7 @@ from bot.bot import Bot
from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_ROLES
from bot.converters import Duration
from bot.pagination import LinePaginator
-from bot.utils.checks import with_role_check, without_role_check
+from bot.utils.checks import has_any_role_check, has_no_roles_check
from bot.utils.messages import send_denial
from bot.utils.scheduling import Scheduler
from bot.utils.time import humanize_delta
@@ -117,9 +117,9 @@ class Reminders(Cog):
If mentions aren't allowed, also return the type of mention(s) disallowed.
"""
- if without_role_check(ctx, *STAFF_ROLES):
+ if await has_no_roles_check(ctx, *STAFF_ROLES):
return False, "members/roles"
- elif without_role_check(ctx, *MODERATION_ROLES):
+ elif await has_no_roles_check(ctx, *MODERATION_ROLES):
return all(isinstance(mention, discord.Member) for mention in mentions), "roles"
else:
return True, ""
@@ -240,7 +240,7 @@ class Reminders(Cog):
Expiration is parsed per: http://strftime.org/
"""
# If the user is not staff, we need to verify whether or not to make a reminder at all.
- if without_role_check(ctx, *STAFF_ROLES):
+ if await has_no_roles_check(ctx, *STAFF_ROLES):
# If they don't have permission to set a reminder in this channel
if ctx.channel.id not in WHITELISTED_CHANNELS:
@@ -431,7 +431,7 @@ class Reminders(Cog):
The check passes when the user is an admin, or if they created the reminder.
"""
- if with_role_check(ctx, Roles.admins):
+ if await has_any_role_check(ctx, Roles.admins):
return True
api_response = await self.bot.api_client.get(f"bot/reminders/{reminder_id}")
diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py
index d96abbd5a..6b6941064 100644
--- a/bot/exts/utils/utils.py
+++ b/bot/exts/utils/utils.py
@@ -7,11 +7,11 @@ from io import StringIO
from typing import Tuple, Union
from discord import Colour, Embed, utils
-from discord.ext.commands import BadArgument, Cog, Context, clean_content, command
+from discord.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role
from bot.bot import Bot
from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES
-from bot.decorators import in_whitelist, with_role
+from bot.decorators import in_whitelist
from bot.pagination import LinePaginator
from bot.utils import messages
@@ -224,7 +224,7 @@ class Utils(Cog):
await ctx.send(embed=embed)
@command(aliases=("poll",))
- @with_role(*MODERATION_ROLES)
+ @has_any_role(*MODERATION_ROLES)
async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None:
"""
Build a quick voting poll with matching reactions with the provided options.
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index f0ef36302..460a937d8 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -1,6 +1,6 @@
import datetime
import logging
-from typing import Callable, Container, Iterable, Optional
+from typing import Callable, Container, Iterable, Optional, Union
from discord.ext.commands import (
BucketType,
@@ -11,6 +11,8 @@ from discord.ext.commands import (
Context,
Cooldown,
CooldownMapping,
+ NoPrivateMessage,
+ has_any_role,
)
from bot import constants
@@ -89,35 +91,32 @@ def in_whitelist_check(
return False
-def with_role_check(ctx: Context, *role_ids: int) -> bool:
- """Returns True if the user has any one of the roles in role_ids."""
- if not ctx.guild: # Return False in a DM
- log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
- "This command is restricted by the with_role decorator. Rejecting request.")
- return False
+async def has_any_role_check(ctx: Context, *roles: Union[str, int]) -> bool:
+ """
+ Returns True if the context's author has any of the specified roles.
- for role in ctx.author.roles:
- if role.id in role_ids:
- log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")
- return True
+ `roles` are the names or IDs of the roles for which to check.
+ False is always returns if the context is outside a guild.
+ """
+ try:
+ return await has_any_role(*roles).predicate(ctx)
+ except CheckFailure:
+ return False
- log.trace(f"{ctx.author} does not have the required role to use "
- f"the '{ctx.command.name}' command, so the request is rejected.")
- return False
+async def has_no_roles_check(ctx: Context, *roles: Union[str, int]) -> bool:
+ """
+ Returns True if the context's author doesn't have any of the specified roles.
-def without_role_check(ctx: Context, *role_ids: int) -> bool:
- """Returns True if the user does not have any of the roles in role_ids."""
- if not ctx.guild: # Return False in a DM
- log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
- "This command is restricted by the without_role decorator. Rejecting request.")
+ `roles` are the names or IDs of the roles for which to check.
+ False is always returns if the context is outside a guild.
+ """
+ try:
+ return not await has_any_role(*roles).predicate(ctx)
+ except NoPrivateMessage:
return False
-
- author_roles = [role.id for role in ctx.author.roles]
- check = all(role not in author_roles for role in role_ids)
- log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
- f"The result of the without_role check was {check}.")
- return check
+ except CheckFailure:
+ return True
def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index 8c4fb764a..e2d44c637 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -253,9 +253,11 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
self.cog.cog_unload()
asyncio_mock.create_task.assert_not_called()
- @mock.patch("bot.exts.moderation.silence.with_role_check")
+ @mock.patch("discord.ext.commands.has_any_role")
@mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))
- def test_cog_check(self, role_check):
+ async def test_cog_check(self, role_check):
"""Role check is called with `MODERATION_ROLES`"""
- self.cog.cog_check(self.ctx)
- role_check.assert_called_once_with(self.ctx, *(1, 2, 3))
+ role_check.return_value.predicate = mock.AsyncMock()
+ await self.cog.cog_check(self.ctx)
+ role_check.assert_called_once_with(*(1, 2, 3))
+ role_check.return_value.predicate.assert_awaited_once_with(self.ctx)
diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py
index e90394ab9..dad751e0d 100644
--- a/tests/bot/exts/moderation/test_slowmode.py
+++ b/tests/bot/exts/moderation/test_slowmode.py
@@ -103,9 +103,11 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase):
f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.'
)
- @mock.patch("bot.exts.moderation.slowmode.with_role_check")
+ @mock.patch("bot.exts.moderation.slowmode.has_any_role")
@mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3))
- def test_cog_check(self, role_check):
+ async def test_cog_check(self, role_check):
"""Role check is called with `MODERATION_ROLES`"""
- self.cog.cog_check(self.ctx)
- role_check.assert_called_once_with(self.ctx, *(1, 2, 3))
+ role_check.return_value.predicate = mock.AsyncMock()
+ await self.cog.cog_check(self.ctx)
+ role_check.assert_called_once_with(*(1, 2, 3))
+ role_check.return_value.predicate.assert_awaited_once_with(self.ctx)
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index de72e5748..883465e0b 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -1,48 +1,50 @@
import unittest
from unittest.mock import MagicMock
+from discord import DMChannel
+
from bot.utils import checks
from bot.utils.checks import InWhitelistCheckFailure
from tests.helpers import MockContext, MockRole
-class ChecksTests(unittest.TestCase):
+class ChecksTests(unittest.IsolatedAsyncioTestCase):
"""Tests the check functions defined in `bot.checks`."""
def setUp(self):
self.ctx = MockContext()
- def test_with_role_check_without_guild(self):
- """`with_role_check` returns `False` if `Context.guild` is None."""
- self.ctx.guild = None
- self.assertFalse(checks.with_role_check(self.ctx))
+ async def test_has_any_role_check_without_guild(self):
+ """`has_any_role_check` returns `False` for non-guild channels."""
+ self.ctx.channel = MagicMock(DMChannel)
+ self.assertFalse(await checks.has_any_role_check(self.ctx))
- def test_with_role_check_without_required_roles(self):
- """`with_role_check` returns `False` if `Context.author` lacks the required role."""
+ async def test_has_any_role_check_without_required_roles(self):
+ """`has_any_role_check` returns `False` if `Context.author` lacks the required role."""
self.ctx.author.roles = []
- self.assertFalse(checks.with_role_check(self.ctx))
+ self.assertFalse(await checks.has_any_role_check(self.ctx))
- def test_with_role_check_with_guild_and_required_role(self):
- """`with_role_check` returns `True` if `Context.author` has the required role."""
+ async def test_has_any_role_check_with_guild_and_required_role(self):
+ """`has_any_role_check` returns `True` if `Context.author` has the required role."""
self.ctx.author.roles.append(MockRole(id=10))
- self.assertTrue(checks.with_role_check(self.ctx, 10))
+ self.assertTrue(await checks.has_any_role_check(self.ctx, 10))
- def test_without_role_check_without_guild(self):
- """`without_role_check` should return `False` when `Context.guild` is None."""
- self.ctx.guild = None
- self.assertFalse(checks.without_role_check(self.ctx))
+ async def test_has_no_roles_check_without_guild(self):
+ """`has_no_roles_check` should return `False` when `Context.guild` is None."""
+ self.ctx.channel = MagicMock(DMChannel)
+ self.assertFalse(await checks.has_no_roles_check(self.ctx))
- def test_without_role_check_returns_false_with_unwanted_role(self):
- """`without_role_check` returns `False` if `Context.author` has unwanted role."""
+ async def test_has_no_roles_check_returns_false_with_unwanted_role(self):
+ """`has_no_roles_check` returns `False` if `Context.author` has unwanted role."""
role_id = 42
self.ctx.author.roles.append(MockRole(id=role_id))
- self.assertFalse(checks.without_role_check(self.ctx, role_id))
+ self.assertFalse(await checks.has_no_roles_check(self.ctx, role_id))
- def test_without_role_check_returns_true_without_unwanted_role(self):
- """`without_role_check` returns `True` if `Context.author` does not have unwanted role."""
+ async def test_has_no_roles_check_returns_true_without_unwanted_role(self):
+ """`has_no_roles_check` returns `True` if `Context.author` does not have unwanted role."""
role_id = 42
self.ctx.author.roles.append(MockRole(id=role_id))
- self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
+ self.assertTrue(await checks.has_no_roles_check(self.ctx, role_id + 10))
def test_in_whitelist_check_correct_channel(self):
"""`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list."""