diff options
| author | 2020-09-21 15:40:00 -0700 | |
|---|---|---|
| committer | 2020-09-21 15:40:00 -0700 | |
| commit | 10e6a708d7d32c67dc924684ffda0a75913e915e (patch) | |
| tree | 105f2ae510b04ba12372d88475b4bf3bdb58892e | |
| parent | Merge PR #817: Write tests for moderation utils (diff) | |
| parent | Fix conflicts caused by #1103 (diff) | |
Merge pull request #1150 from python-discord/feat/backend/217/has_any_role
Implement with(out)_role checks and decorators using has_any_role
28 files changed, 173 insertions, 180 deletions
| diff --git a/bot/decorators.py b/bot/decorators.py index 500197c89..2518124da 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -6,13 +6,12 @@ from functools import wraps  from typing import Callable, Container, Optional, Union  from weakref import WeakValueDictionary -from discord import Colour, Embed, Member -from discord.errors import NotFound +from discord import Colour, Embed, Member, NotFound  from discord.ext import commands  from discord.ext.commands import Cog, Context  from bot.constants import Channels, ERROR_REPLIES, RedirectOutput -from bot.utils.checks import in_whitelist_check, with_role_check, without_role_check +from bot.utils.checks import in_whitelist_check  log = logging.getLogger(__name__) @@ -45,18 +44,22 @@ def in_whitelist(      return commands.check(predicate) -def with_role(*role_ids: int) -> Callable: -    """Returns True if the user has any one of the roles in role_ids.""" -    async def predicate(ctx: Context) -> bool: -        """With role checker predicate.""" -        return with_role_check(ctx, *role_ids) -    return commands.check(predicate) - +def has_no_roles(*roles: Union[str, int]) -> Callable: +    """ +    Returns True if the user does not have any of the roles specified. -def without_role(*role_ids: int) -> Callable: -    """Returns True if the user does not have any of the roles in role_ids.""" +    `roles` are the names or IDs of the disallowed roles. +    """      async def predicate(ctx: Context) -> bool: -        return without_role_check(ctx, *role_ids) +        try: +            await commands.has_any_role(*roles).predicate(ctx) +        except commands.MissingAnyRole: +            return True +        else: +            # This error is never shown to users, so don't bother trying to make it too pretty. +            roles_ = ", ".join(f"'{item}'" for item in roles) +            raise commands.CheckFailure(f"You have at least one of the disallowed roles: {roles_}") +      return commands.check(predicate) diff --git a/bot/exts/filters/filter_lists.py b/bot/exts/filters/filter_lists.py index c15adc461..232c1e48b 100644 --- a/bot/exts/filters/filter_lists.py +++ b/bot/exts/filters/filter_lists.py @@ -2,14 +2,13 @@ import logging  from typing import Optional  from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group, has_any_role  from bot import constants  from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.converters import ValidDiscordServerInvite, ValidFilterListType  from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check  log = logging.getLogger(__name__) @@ -263,9 +262,9 @@ class FilterLists(Cog):          """Syncs both allowlists and denylists with the API."""          await self._sync_data(ctx) -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *constants.MODERATION_ROLES) +        return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx)  def setup(bot: Bot) -> None: diff --git a/bot/exts/fun/off_topic_names.py b/bot/exts/fun/off_topic_names.py index ce95450e0..b9d235fa2 100644 --- a/bot/exts/fun/off_topic_names.py +++ b/bot/exts/fun/off_topic_names.py @@ -4,13 +4,12 @@ import logging  from datetime import datetime, timedelta  from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES  from bot.converters import OffTopicName -from bot.decorators import with_role  from bot.pagination import LinePaginator  CHANNELS = (Channels.off_topic_0, Channels.off_topic_1, Channels.off_topic_2) @@ -67,13 +66,13 @@ class OffTopicNames(Cog):              self.updater_task = self.bot.loop.create_task(coro)      @group(name='otname', aliases=('otnames', 'otn'), invoke_without_command=True) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def otname_group(self, ctx: Context) -> None:          """Add or list items from the off-topic channel name rotation."""          await ctx.send_help(ctx.command)      @otname_group.command(name='add', aliases=('a',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def add_command(self, ctx: Context, *, name: OffTopicName) -> None:          """          Adds a new off-topic name to the rotation. @@ -96,7 +95,7 @@ class OffTopicNames(Cog):              await self._add_name(ctx, name)      @otname_group.command(name='forceadd', aliases=('fa',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None:          """Forcefully adds a new off-topic name to the rotation."""          await self._add_name(ctx, name) @@ -109,7 +108,7 @@ class OffTopicNames(Cog):          await ctx.send(f":ok_hand: Added `{name}` to the names list.")      @otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd')) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None:          """Removes a off-topic name from the rotation."""          await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}') @@ -118,7 +117,7 @@ class OffTopicNames(Cog):          await ctx.send(f":ok_hand: Removed `{name}` from the names list.")      @otname_group.command(name='list', aliases=('l',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def list_command(self, ctx: Context) -> None:          """          Lists all currently known off-topic channel names in a paginator. @@ -138,7 +137,7 @@ class OffTopicNames(Cog):              await ctx.send(embed=embed)      @otname_group.command(name='search', aliases=('s',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def search_command(self, ctx: Context, *, query: OffTopicName) -> None:          """Search for an off-topic name."""          result = await self.bot.api_client.get('bot/off-topic-channel-names') diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py index 0f9cac89e..17142071f 100644 --- a/bot/exts/help_channels.py +++ b/bot/exts/help_channels.py @@ -14,7 +14,6 @@ from discord.ext import commands  from bot import constants  from bot.bot import Bot  from bot.utils import RedisCache -from bot.utils.checks import with_role_check  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) @@ -196,12 +195,12 @@ class HelpChannels(commands.Cog):              return True          log.trace(f"{ctx.author} is not the help channel claimant, checking roles.") -        role_check = with_role_check(ctx, *constants.HelpChannels.cmd_whitelist) +        has_role = await commands.has_any_role(*constants.HelpChannels.cmd_whitelist).predicate(ctx) -        if role_check: +        if has_role:              self.bot.stats.incr("help.dormant_invoke.staff") -        return role_check +        return has_role      @commands.command(name="close", aliases=["dormant", "solved"], enabled=False)      async def close_command(self, ctx: commands.Context) -> None: diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py index 30c793c75..e50b9b32b 100644 --- a/bot/exts/info/doc.py +++ b/bot/exts/info/doc.py @@ -21,7 +21,6 @@ from urllib3.exceptions import ProtocolError  from bot.bot import Bot  from bot.constants import MODERATION_ROLES, RedirectOutput  from bot.converters import ValidPythonIdentifier, ValidURL -from bot.decorators import with_role  from bot.pagination import LinePaginator  from bot.utils.messages import wait_for_deletion @@ -396,7 +395,7 @@ class Doc(commands.Cog):                  await wait_for_deletion(msg, (ctx.author.id,), client=self.bot)      @docs_group.command(name='set', aliases=('s',)) -    @with_role(*MODERATION_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def set_command(          self, ctx: commands.Context, package_name: ValidPythonIdentifier,          base_url: ValidURL, inventory_url: InventoryURL @@ -433,7 +432,7 @@ class Doc(commands.Cog):          await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.")      @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) -    @with_role(*MODERATION_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None:          """          Removes the specified package from the database. @@ -450,7 +449,7 @@ class Doc(commands.Cog):          await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.")      @docs_group.command(name="refresh", aliases=("rfsh", "r")) -    @with_role(*MODERATION_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def refresh_command(self, ctx: commands.Context) -> None:          """Refresh inventories and send differences to channel."""          old_inventories = set(self.base_urls) diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 55ecb2836..581b3a227 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -8,18 +8,19 @@ from typing import Any, Mapping, Optional, Tuple, Union  from discord import ChannelType, Colour, CustomActivity, Embed, Guild, Member, Message, Role, Status, utils  from discord.abc import GuildChannel -from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role  from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.decorators import in_whitelist, with_role +from bot.decorators import in_whitelist  from bot.pagination import LinePaginator -from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, with_role_check +from bot.utils.checks import InWhitelistCheckFailure, cooldown_with_role_bypass, has_no_roles_check  from bot.utils.time import time_since  log = logging.getLogger(__name__) +  STATUS_EMOTES = {      Status.offline: constants.Emojis.status_offline,      Status.dnd: constants.Emojis.status_dnd, @@ -76,7 +77,7 @@ class Information(Cog):          channel_type_list = sorted(channel_type_list)          return "\n".join(channel_type_list) -    @with_role(*constants.MODERATION_ROLES) +    @has_any_role(*constants.MODERATION_ROLES)      @command(name="roles")      async def roles_info(self, ctx: Context) -> None:          """Returns a list of all roles and their corresponding IDs.""" @@ -96,7 +97,7 @@ class Information(Cog):          await LinePaginator.paginate(role_list, ctx, embed, empty=False) -    @with_role(*constants.MODERATION_ROLES) +    @has_any_role(*constants.MODERATION_ROLES)      @command(name="role")      async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None:          """ @@ -197,12 +198,12 @@ class Information(Cog):              user = ctx.author          # Do a role check if this is being executed on someone other than the caller -        elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): +        elif user != ctx.author and await has_no_roles_check(ctx, *constants.MODERATION_ROLES):              await ctx.send("You may not use this command on users other than yourself.")              return          # Non-staff may only do this in #bot-commands -        if not with_role_check(ctx, *constants.STAFF_ROLES): +        if await has_no_roles_check(ctx, *constants.STAFF_ROLES):              if not ctx.channel.id == constants.Channels.bot_commands:                  raise InWhitelistCheckFailure(constants.Channels.bot_commands) diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py index 5d9e2c20b..635162308 100644 --- a/bot/exts/info/reddit.py +++ b/bot/exts/info/reddit.py @@ -8,14 +8,13 @@ from typing import List  from aiohttp import BasicAuth, ClientError  from discord import Colour, Embed, TextChannel -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from discord.ext.tasks import loop  from discord.utils import escape_markdown  from bot.bot import Bot  from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks  from bot.converters import Subreddit -from bot.decorators import with_role  from bot.pagination import LinePaginator  from bot.utils.messages import sub_clyde @@ -282,7 +281,7 @@ class Reddit(Cog):          await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed) -    @with_role(*STAFF_ROLES) +    @has_any_role(*STAFF_ROLES)      @reddit_group.command(name="subreddits", aliases=("subs",))      async def subreddits_command(self, ctx: Context) -> None:          """Send a paginated embed of all the subreddits we're relaying.""" diff --git a/bot/exts/moderation/defcon.py b/bot/exts/moderation/defcon.py index 6e4008777..3bf462877 100644 --- a/bot/exts/moderation/defcon.py +++ b/bot/exts/moderation/defcon.py @@ -6,11 +6,10 @@ from datetime import datetime, timedelta  from enum import Enum  from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.bot import Bot  from bot.constants import Channels, Colours, Emojis, Event, Icons, MODERATION_ROLES, Roles -from bot.decorators import with_role  from bot.exts.moderation.modlog import ModLog  log = logging.getLogger(__name__) @@ -119,7 +118,7 @@ class Defcon(Cog):                  )      @group(name='defcon', aliases=('dc',), invoke_without_command=True) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def defcon_group(self, ctx: Context) -> None:          """Check the DEFCON status or run a subcommand."""          await ctx.send_help(ctx.command) @@ -163,7 +162,7 @@ class Defcon(Cog):              self.bot.stats.gauge("defcon.threshold", days)      @defcon_group.command(name='enable', aliases=('on', 'e'), root_aliases=("defon",)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def enable_command(self, ctx: Context) -> None:          """          Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! @@ -176,7 +175,7 @@ class Defcon(Cog):          await self.update_channel_topic()      @defcon_group.command(name='disable', aliases=('off', 'd'), root_aliases=("defoff",)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def disable_command(self, ctx: Context) -> None:          """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!"""          self.enabled = False @@ -184,7 +183,7 @@ class Defcon(Cog):          await self.update_channel_topic()      @defcon_group.command(name='status', aliases=('s',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def status_command(self, ctx: Context) -> None:          """Check the current status of DEFCON mode."""          embed = Embed( @@ -196,7 +195,7 @@ class Defcon(Cog):          await ctx.send(embed=embed)      @defcon_group.command(name='days') -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def days_command(self, ctx: Context, days: int) -> None:          """Set how old an account must be to join the server, in days, with DEFCON mode enabled."""          self.days = timedelta(days=days) diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 0d8f340b4..7a3fe49bb 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -10,7 +10,7 @@ from bot import constants  from bot.bot import Bot  from bot.converters import UserMentionOrID  from bot.utils import RedisCache -from bot.utils.checks import in_whitelist_check, with_role_check +from bot.utils.checks import in_whitelist_check  from bot.utils.messages import send_attachments  from bot.utils.webhooks import send_webhook @@ -105,10 +105,10 @@ class DMRelay(Cog):              except discord.HTTPException:                  log.exception("Failed to send an attachment to the webhook") -    def cog_check(self, ctx: commands.Context) -> bool: +    async def cog_check(self, ctx: commands.Context) -> bool:          """Only allow moderators to invoke the commands in this cog."""          checks = [ -            with_role_check(ctx, *constants.MODERATION_ROLES), +            await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx),              in_whitelist_check(                  ctx,                  channels=[constants.Channels.dm_log], diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 84ea47371..5fa62d3c4 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -15,7 +15,6 @@ from bot.decorators import respect_role_hierarchy  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.exts.moderation.infraction._utils import UserSnowflake -from bot.utils.checks import with_role_check  log = logging.getLogger(__name__) @@ -357,9 +356,9 @@ class Infractions(InfractionScheduler, commands.Cog):      # endregion      # This cannot be static (must have a __func__ attribute). -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *constants.MODERATION_ROLES) +        return await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx)      # This cannot be static (must have a __func__ attribute).      async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index 5875abd26..15ee28537 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -15,7 +15,7 @@ from bot.exts.moderation.infraction.infractions import Infractions  from bot.exts.moderation.modlog import ModLog  from bot.pagination import LinePaginator  from bot.utils import time -from bot.utils.checks import in_whitelist_check, with_role_check +from bot.utils.checks import in_whitelist_check  log = logging.getLogger(__name__) @@ -282,10 +282,10 @@ class ModManagement(commands.Cog):      # endregion      # This cannot be static (must have a __func__ attribute). -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators inside moderator channels to invoke the commands in this cog."""          checks = [ -            with_role_check(ctx, *constants.MODERATION_ROLES), +            await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx),              in_whitelist_check(                  ctx,                  channels=constants.MODERATION_CHANNELS, diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index a4e78c4d3..29f41f2ab 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -6,14 +6,13 @@ import typing as t  from pathlib import Path  from discord import Colour, Embed, Member -from discord.ext.commands import Cog, Context, command +from discord.ext.commands import Cog, Context, command, has_any_role  from bot import constants  from bot.bot import Bot  from bot.converters import Expiry  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler -from bot.utils.checks import with_role_check  from bot.utils.time import format_infraction  log = logging.getLogger(__name__) @@ -234,9 +233,9 @@ class Superstarify(InfractionScheduler, Cog):          return rng.choice(STAR_NAMES)      # This cannot be static (must have a __func__ attribute). -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *constants.MODERATION_ROLES) +        return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx)  def setup(bot: Bot) -> None: diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 4af87c724..ac0c1c85e 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -10,7 +10,6 @@ from discord.ext.commands import Context  from bot.bot import Bot  from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles  from bot.converters import HushDurationConverter -from bot.utils.checks import with_role_check  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) @@ -160,9 +159,9 @@ class Silence(commands.Cog):              asyncio.create_task(self._mod_alerts_channel.send(message))      # This cannot be static (must have a __func__ attribute). -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *MODERATION_ROLES) +        return await commands.has_any_role(*MODERATION_ROLES).predicate(ctx)  def setup(bot: Bot) -> None: diff --git a/bot/exts/moderation/slowmode.py b/bot/exts/moderation/slowmode.py index 1d055afac..efd862aa5 100644 --- a/bot/exts/moderation/slowmode.py +++ b/bot/exts/moderation/slowmode.py @@ -4,12 +4,11 @@ from typing import Optional  from dateutil.relativedelta import relativedelta  from discord import TextChannel -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.bot import Bot  from bot.constants import Emojis, MODERATION_ROLES  from bot.converters import DurationDelta -from bot.decorators import with_role_check  from bot.utils import time  log = logging.getLogger(__name__) @@ -87,9 +86,9 @@ class Slowmode(Cog):              f'{Emojis.check_mark} The slowmode delay for {channel.mention} has been reset to 0 seconds.'          ) -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *MODERATION_ROLES) +        return await has_any_role(*MODERATION_ROLES).predicate(ctx)  def setup(bot: Bot) -> None: diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index 53fa0730b..8ec68ac1e 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -6,14 +6,14 @@ from datetime import datetime, timedelta  import discord  from discord.ext import tasks -from discord.ext.commands import Cog, Context, command, group +from discord.ext.commands import Cog, Context, command, group, has_any_role  from discord.utils import snowflake_time  from bot import constants  from bot.bot import Bot -from bot.decorators import in_whitelist, with_role, without_role +from bot.decorators import has_no_roles, in_whitelist  from bot.exts.moderation.modlog import ModLog -from bot.utils.checks import InWhitelistCheckFailure, without_role_check +from bot.utils.checks import InWhitelistCheckFailure, has_no_roles_check  from bot.utils.redis_cache import RedisCache  log = logging.getLogger(__name__) @@ -568,7 +568,7 @@ class Verification(Cog):      # endregion      # region: task management commands -    @with_role(*constants.MODERATION_ROLES) +    @has_any_role(*constants.MODERATION_ROLES)      @group(name="verification")      async def verification_group(self, ctx: Context) -> None:          """Manage internal verification tasks.""" @@ -653,7 +653,7 @@ class Verification(Cog):          self.bot.stats.incr(f"verification.{category}")      @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) -    @without_role(constants.Roles.verified) +    @has_no_roles(constants.Roles.verified)      @in_whitelist(channels=(constants.Channels.verification,))      async def accept_command(self, ctx: Context, *_) -> None:  # We don't actually care about the args          """Accept our rules and gain access to the rest of the server.""" @@ -736,9 +736,10 @@ class Verification(Cog):              error.handled = True      @staticmethod -    def bot_check(ctx: Context) -> bool: +    async def bot_check(ctx: Context) -> bool:          """Block any command within the verification channel that is not !accept.""" -        if ctx.channel.id == constants.Channels.verification and without_role_check(ctx, *constants.MODERATION_ROLES): +        is_verification = ctx.channel.id == constants.Channels.verification +        if is_verification and await has_no_roles_check(ctx, *constants.MODERATION_ROLES):              return ctx.command.name == "accept"          else:              return True diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py index d7127b5c4..3b44056d3 100644 --- a/bot/exts/moderation/watchchannels/bigbrother.py +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -2,12 +2,11 @@ import logging  import textwrap  from collections import ChainMap -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES, Webhooks  from bot.converters import FetchedMember -from bot.decorators import with_role  from bot.exts.moderation.infraction._utils import post_infraction  from bot.exts.moderation.watchchannels._watchchannel import WatchChannel @@ -28,13 +27,13 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):          )      @group(name='bigbrother', aliases=('bb',), invoke_without_command=True) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def bigbrother_group(self, ctx: Context) -> None:          """Monitors users by relaying their messages to the Big Brother watch channel."""          await ctx.send_help(ctx.command)      @bigbrother_group.command(name='watched', aliases=('all', 'list')) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def watched_command(          self, ctx: Context, oldest_first: bool = False, update_cache: bool = True      ) -> None: @@ -49,7 +48,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):          await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache)      @bigbrother_group.command(name='oldest') -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None:          """          Shows Big Brother monitored users ordered by oldest watched. @@ -60,7 +59,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):          await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache)      @bigbrother_group.command(name='watch', aliases=('w',), root_aliases=('watch',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:          """          Relay messages sent by the given `user` to the `#big-brother` channel. @@ -71,7 +70,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):          await self.apply_watch(ctx, user, reason)      @bigbrother_group.command(name='unwatch', aliases=('uw',), root_aliases=('unwatch',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:          """Stop relaying messages by the given `user`."""          await self.apply_unwatch(ctx, user, reason) diff --git a/bot/exts/moderation/watchchannels/talentpool.py b/bot/exts/moderation/watchchannels/talentpool.py index 3724e94e6..a77dbe156 100644 --- a/bot/exts/moderation/watchchannels/talentpool.py +++ b/bot/exts/moderation/watchchannels/talentpool.py @@ -4,13 +4,12 @@ from collections import ChainMap  from typing import Union  from discord import Color, Embed, Member, User -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks  from bot.converters import FetchedMember -from bot.decorators import with_role  from bot.exts.moderation.watchchannels._watchchannel import WatchChannel  from bot.pagination import LinePaginator  from bot.utils import time @@ -32,13 +31,13 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          )      @group(name='talentpool', aliases=('tp', 'talent', 'nomination', 'n'), invoke_without_command=True) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def nomination_group(self, ctx: Context) -> None:          """Highlights the activity of helper nominees by relaying their messages to the talent pool channel."""          await ctx.send_help(ctx.command)      @nomination_group.command(name='watched', aliases=('all', 'list'), root_aliases=("nominees",)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def watched_command(          self, ctx: Context, oldest_first: bool = False, update_cache: bool = True      ) -> None: @@ -53,7 +52,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          await self.list_watched_users(ctx, oldest_first=oldest_first, update_cache=update_cache)      @nomination_group.command(name='oldest') -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def oldest_command(self, ctx: Context, update_cache: bool = True) -> None:          """          Shows talent pool monitored users ordered by oldest nomination. @@ -64,7 +63,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          await ctx.invoke(self.watched_command, oldest_first=True, update_cache=update_cache)      @nomination_group.command(name='watch', aliases=('w', 'add', 'a'), root_aliases=("nominate",)) -    @with_role(*STAFF_ROLES) +    @has_any_role(*STAFF_ROLES)      async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:          """          Relay messages sent by the given `user` to the `#talent-pool` channel. @@ -129,7 +128,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          await ctx.send(msg)      @nomination_group.command(name='history', aliases=('info', 'search')) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def history_command(self, ctx: Context, user: FetchedMember) -> None:          """Shows the specified user's nomination history."""          result = await self.bot.api_client.get( @@ -158,7 +157,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          )      @nomination_group.command(name='unwatch', aliases=('end', ), root_aliases=("unnominate",)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None:          """          Ends the active nomination of the specified user with the given reason. @@ -171,13 +170,13 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):              await ctx.send(":x: The specified user does not have an active nomination")      @nomination_group.group(name='edit', aliases=('e',), invoke_without_command=True) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def nomination_edit_group(self, ctx: Context) -> None:          """Commands to edit nominations."""          await ctx.send_help(ctx.command)      @nomination_edit_group.command(name='reason') -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def edit_reason_command(self, ctx: Context, nomination_id: int, *, reason: str) -> None:          """          Edits the reason/unnominate reason for the nomination with the given `id` depending on the status. diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py index 66f340a99..7ed487d47 100644 --- a/bot/exts/utils/bot.py +++ b/bot/exts/utils/bot.py @@ -5,11 +5,10 @@ import time  from typing import Optional, Tuple  from discord import Embed, Message, RawMessageUpdateEvent, TextChannel -from discord.ext.commands import Cog, Context, command, group +from discord.ext.commands import Cog, Context, command, group, has_any_role  from bot.bot import Bot  from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.decorators import with_role  from bot.exts.filters.token_remover import TokenRemover  from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE  from bot.utils.messages import wait_for_deletion @@ -39,13 +38,13 @@ class BotCog(Cog, name="Bot"):          self.codeblock_message_ids = {}      @group(invoke_without_command=True, name="bot", hidden=True) -    @with_role(Roles.verified) +    @has_any_role(Roles.verified)      async def botinfo_group(self, ctx: Context) -> None:          """Bot informational commands."""          await ctx.send_help(ctx.command)      @botinfo_group.command(name='about', aliases=('info',), hidden=True) -    @with_role(Roles.verified) +    @has_any_role(Roles.verified)      async def about_command(self, ctx: Context) -> None:          """Get information about the bot."""          embed = Embed( @@ -63,7 +62,7 @@ class BotCog(Cog, name="Bot"):          await ctx.send(embed=embed)      @command(name='echo', aliases=('print',)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None:          """Repeat the given message in either a specified channel or the current channel."""          if channel is None: @@ -72,7 +71,7 @@ class BotCog(Cog, name="Bot"):              await channel.send(text)      @command(name='embed') -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def embed_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None:          """Send the input within an embed to either a specified channel or the current channel."""          embed = Embed(description=text) diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py index d9a7aafe1..236603dba 100644 --- a/bot/exts/utils/clean.py +++ b/bot/exts/utils/clean.py @@ -5,13 +5,12 @@ from typing import Iterable, Optional  from discord import Colour, Embed, Message, TextChannel, User  from discord.ext import commands -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.bot import Bot  from bot.constants import (      Channels, CleanMessages, Colours, Event, Icons, MODERATION_ROLES, NEGATIVE_REPLIES  ) -from bot.decorators import with_role  from bot.exts.moderation.modlog import ModLog  log = logging.getLogger(__name__) @@ -192,13 +191,13 @@ class Clean(Cog):          )      @group(invoke_without_command=True, name="clean", aliases=["purge"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_group(self, ctx: Context) -> None:          """Commands for cleaning messages in channels."""          await ctx.send_help(ctx.command)      @clean_group.command(name="user", aliases=["users"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_user(          self,          ctx: Context, @@ -210,7 +209,7 @@ class Clean(Cog):          await self._clean_messages(amount, ctx, user=user, channels=channels)      @clean_group.command(name="all", aliases=["everything"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_all(          self,          ctx: Context, @@ -221,7 +220,7 @@ class Clean(Cog):          await self._clean_messages(amount, ctx, channels=channels)      @clean_group.command(name="bots", aliases=["bot"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_bots(          self,          ctx: Context, @@ -232,7 +231,7 @@ class Clean(Cog):          await self._clean_messages(amount, ctx, bots_only=True, channels=channels)      @clean_group.command(name="regex", aliases=["word", "expression"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_regex(          self,          ctx: Context, @@ -244,7 +243,7 @@ class Clean(Cog):          await self._clean_messages(amount, ctx, regex=regex, channels=channels)      @clean_group.command(name="message", aliases=["messages"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_message(self, ctx: Context, message: Message) -> None:          """Delete all messages until certain message, stop cleaning after hitting the `message`."""          await self._clean_messages( @@ -255,7 +254,7 @@ class Clean(Cog):          )      @clean_group.command(name="stop", aliases=["cancel", "abort"]) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def clean_cancel(self, ctx: Context) -> None:          """If there is an ongoing cleaning process, attempt to immediately cancel it."""          self.cleaning = False diff --git a/bot/exts/utils/eval.py b/bot/exts/utils/eval.py index 23e5998d8..6419b320e 100644 --- a/bot/exts/utils/eval.py +++ b/bot/exts/utils/eval.py @@ -9,11 +9,10 @@ from io import StringIO  from typing import Any, Optional, Tuple  import discord -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, group, has_any_role  from bot.bot import Bot  from bot.constants import Roles -from bot.decorators import with_role  from bot.interpreter import Interpreter  from bot.utils import find_nth_occurrence, send_to_paste_service @@ -199,14 +198,14 @@ async def func():  # (None,) -> Any          await ctx.send(f"```py\n{out}```", embed=embed)      @group(name='internal', aliases=('int',)) -    @with_role(Roles.owners, Roles.admins) +    @has_any_role(Roles.owners, Roles.admins)      async def internal_group(self, ctx: Context) -> None:          """Internal commands. Top secret!"""          if not ctx.invoked_subcommand:              await ctx.send_help(ctx.command)      @internal_group.command(name='eval', aliases=('e',)) -    @with_role(Roles.admins, Roles.owners) +    @has_any_role(Roles.admins, Roles.owners)      async def eval(self, ctx: Context, *, code: str) -> None:          """Run eval in a REPL-like format."""          code = code.strip("`") diff --git a/bot/exts/utils/extensions.py b/bot/exts/utils/extensions.py index 123f356e8..418db0150 100644 --- a/bot/exts/utils/extensions.py +++ b/bot/exts/utils/extensions.py @@ -11,7 +11,6 @@ from bot import exts  from bot.bot import Bot  from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs  from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check  from bot.utils.extensions import EXTENSIONS, unqualify  log = logging.getLogger(__name__) @@ -248,9 +247,9 @@ class Extensions(commands.Cog):          return msg, error_msg      # This cannot be static (must have a __func__ attribute). -    def cog_check(self, ctx: Context) -> bool: +    async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators and core developers to invoke the commands in this cog.""" -        return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) +        return await commands.has_any_role(*MODERATION_ROLES, Roles.core_developers).predicate(ctx)      # This cannot be static (must have a __func__ attribute).      async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/exts/utils/jams.py b/bot/exts/utils/jams.py index b3102db2f..1c0988343 100644 --- a/bot/exts/utils/jams.py +++ b/bot/exts/utils/jams.py @@ -7,7 +7,6 @@ from more_itertools import unique_everseen  from bot.bot import Bot  from bot.constants import Roles -from bot.decorators import with_role  log = logging.getLogger(__name__) @@ -22,7 +21,7 @@ class CodeJams(commands.Cog):          self.bot = bot      @commands.command() -    @with_role(Roles.admins) +    @commands.has_any_role(Roles.admins)      async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None:          """          Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 08bce2153..6806f2889 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -15,7 +15,7 @@ from bot.bot import Bot  from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_ROLES  from bot.converters import Duration  from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check, without_role_check +from bot.utils.checks import has_any_role_check, has_no_roles_check  from bot.utils.messages import send_denial  from bot.utils.scheduling import Scheduler  from bot.utils.time import humanize_delta @@ -117,9 +117,9 @@ class Reminders(Cog):          If mentions aren't allowed, also return the type of mention(s) disallowed.          """ -        if without_role_check(ctx, *STAFF_ROLES): +        if await has_no_roles_check(ctx, *STAFF_ROLES):              return False, "members/roles" -        elif without_role_check(ctx, *MODERATION_ROLES): +        elif await has_no_roles_check(ctx, *MODERATION_ROLES):              return all(isinstance(mention, discord.Member) for mention in mentions), "roles"          else:              return True, "" @@ -240,7 +240,7 @@ class Reminders(Cog):          Expiration is parsed per: http://strftime.org/          """          # If the user is not staff, we need to verify whether or not to make a reminder at all. -        if without_role_check(ctx, *STAFF_ROLES): +        if await has_no_roles_check(ctx, *STAFF_ROLES):              # If they don't have permission to set a reminder in this channel              if ctx.channel.id not in WHITELISTED_CHANNELS: @@ -431,7 +431,7 @@ class Reminders(Cog):          The check passes when the user is an admin, or if they created the reminder.          """ -        if with_role_check(ctx, Roles.admins): +        if await has_any_role_check(ctx, Roles.admins):              return True          api_response = await self.bot.api_client.get(f"bot/reminders/{reminder_id}") diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index d96abbd5a..6b6941064 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -7,11 +7,11 @@ from io import StringIO  from typing import Tuple, Union  from discord import Colour, Embed, utils -from discord.ext.commands import BadArgument, Cog, Context, clean_content, command +from discord.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role  from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES -from bot.decorators import in_whitelist, with_role +from bot.decorators import in_whitelist  from bot.pagination import LinePaginator  from bot.utils import messages @@ -224,7 +224,7 @@ class Utils(Cog):          await ctx.send(embed=embed)      @command(aliases=("poll",)) -    @with_role(*MODERATION_ROLES) +    @has_any_role(*MODERATION_ROLES)      async def vote(self, ctx: Context, title: clean_content(fix_channel_mentions=True), *options: str) -> None:          """          Build a quick voting poll with matching reactions with the provided options. diff --git a/bot/utils/checks.py b/bot/utils/checks.py index f0ef36302..460a937d8 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,6 +1,6 @@  import datetime  import logging -from typing import Callable, Container, Iterable, Optional +from typing import Callable, Container, Iterable, Optional, Union  from discord.ext.commands import (      BucketType, @@ -11,6 +11,8 @@ from discord.ext.commands import (      Context,      Cooldown,      CooldownMapping, +    NoPrivateMessage, +    has_any_role,  )  from bot import constants @@ -89,35 +91,32 @@ def in_whitelist_check(      return False -def with_role_check(ctx: Context, *role_ids: int) -> bool: -    """Returns True if the user has any one of the roles in role_ids.""" -    if not ctx.guild:  # Return False in a DM -        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " -                  "This command is restricted by the with_role decorator. Rejecting request.") -        return False +async def has_any_role_check(ctx: Context, *roles: Union[str, int]) -> bool: +    """ +    Returns True if the context's author has any of the specified roles. -    for role in ctx.author.roles: -        if role.id in role_ids: -            log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.") -            return True +    `roles` are the names or IDs of the roles for which to check. +    False is always returns if the context is outside a guild. +    """ +    try: +        return await has_any_role(*roles).predicate(ctx) +    except CheckFailure: +        return False -    log.trace(f"{ctx.author} does not have the required role to use " -              f"the '{ctx.command.name}' command, so the request is rejected.") -    return False +async def has_no_roles_check(ctx: Context, *roles: Union[str, int]) -> bool: +    """ +    Returns True if the context's author doesn't have any of the specified roles. -def without_role_check(ctx: Context, *role_ids: int) -> bool: -    """Returns True if the user does not have any of the roles in role_ids.""" -    if not ctx.guild:  # Return False in a DM -        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " -                  "This command is restricted by the without_role decorator. Rejecting request.") +    `roles` are the names or IDs of the roles for which to check. +    False is always returns if the context is outside a guild. +    """ +    try: +        return not await has_any_role(*roles).predicate(ctx) +    except NoPrivateMessage:          return False - -    author_roles = [role.id for role in ctx.author.roles] -    check = all(role not in author_roles for role in role_ids) -    log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -              f"The result of the without_role check was {check}.") -    return check +    except CheckFailure: +        return True  def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *, diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 8c4fb764a..e2d44c637 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -253,9 +253,11 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          self.cog.cog_unload()          asyncio_mock.create_task.assert_not_called() -    @mock.patch("bot.exts.moderation.silence.with_role_check") +    @mock.patch("discord.ext.commands.has_any_role")      @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) -    def test_cog_check(self, role_check): +    async def test_cog_check(self, role_check):          """Role check is called with `MODERATION_ROLES`""" -        self.cog.cog_check(self.ctx) -        role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) +        role_check.return_value.predicate = mock.AsyncMock() +        await self.cog.cog_check(self.ctx) +        role_check.assert_called_once_with(*(1, 2, 3)) +        role_check.return_value.predicate.assert_awaited_once_with(self.ctx) diff --git a/tests/bot/exts/moderation/test_slowmode.py b/tests/bot/exts/moderation/test_slowmode.py index e90394ab9..dad751e0d 100644 --- a/tests/bot/exts/moderation/test_slowmode.py +++ b/tests/bot/exts/moderation/test_slowmode.py @@ -103,9 +103,11 @@ class SlowmodeTests(unittest.IsolatedAsyncioTestCase):              f'{Emojis.check_mark} The slowmode delay for #meta has been reset to 0 seconds.'          ) -    @mock.patch("bot.exts.moderation.slowmode.with_role_check") +    @mock.patch("bot.exts.moderation.slowmode.has_any_role")      @mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3)) -    def test_cog_check(self, role_check): +    async def test_cog_check(self, role_check):          """Role check is called with `MODERATION_ROLES`""" -        self.cog.cog_check(self.ctx) -        role_check.assert_called_once_with(self.ctx, *(1, 2, 3)) +        role_check.return_value.predicate = mock.AsyncMock() +        await self.cog.cog_check(self.ctx) +        role_check.assert_called_once_with(*(1, 2, 3)) +        role_check.return_value.predicate.assert_awaited_once_with(self.ctx) diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index de72e5748..883465e0b 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,48 +1,50 @@  import unittest  from unittest.mock import MagicMock +from discord import DMChannel +  from bot.utils import checks  from bot.utils.checks import InWhitelistCheckFailure  from tests.helpers import MockContext, MockRole -class ChecksTests(unittest.TestCase): +class ChecksTests(unittest.IsolatedAsyncioTestCase):      """Tests the check functions defined in `bot.checks`."""      def setUp(self):          self.ctx = MockContext() -    def test_with_role_check_without_guild(self): -        """`with_role_check` returns `False` if `Context.guild` is None.""" -        self.ctx.guild = None -        self.assertFalse(checks.with_role_check(self.ctx)) +    async def test_has_any_role_check_without_guild(self): +        """`has_any_role_check` returns `False` for non-guild channels.""" +        self.ctx.channel = MagicMock(DMChannel) +        self.assertFalse(await checks.has_any_role_check(self.ctx)) -    def test_with_role_check_without_required_roles(self): -        """`with_role_check` returns `False` if `Context.author` lacks the required role.""" +    async def test_has_any_role_check_without_required_roles(self): +        """`has_any_role_check` returns `False` if `Context.author` lacks the required role."""          self.ctx.author.roles = [] -        self.assertFalse(checks.with_role_check(self.ctx)) +        self.assertFalse(await checks.has_any_role_check(self.ctx)) -    def test_with_role_check_with_guild_and_required_role(self): -        """`with_role_check` returns `True` if `Context.author` has the required role.""" +    async def test_has_any_role_check_with_guild_and_required_role(self): +        """`has_any_role_check` returns `True` if `Context.author` has the required role."""          self.ctx.author.roles.append(MockRole(id=10)) -        self.assertTrue(checks.with_role_check(self.ctx, 10)) +        self.assertTrue(await checks.has_any_role_check(self.ctx, 10)) -    def test_without_role_check_without_guild(self): -        """`without_role_check` should return `False` when `Context.guild` is None.""" -        self.ctx.guild = None -        self.assertFalse(checks.without_role_check(self.ctx)) +    async def test_has_no_roles_check_without_guild(self): +        """`has_no_roles_check` should return `False` when `Context.guild` is None.""" +        self.ctx.channel = MagicMock(DMChannel) +        self.assertFalse(await checks.has_no_roles_check(self.ctx)) -    def test_without_role_check_returns_false_with_unwanted_role(self): -        """`without_role_check` returns `False` if `Context.author` has unwanted role.""" +    async def test_has_no_roles_check_returns_false_with_unwanted_role(self): +        """`has_no_roles_check` returns `False` if `Context.author` has unwanted role."""          role_id = 42          self.ctx.author.roles.append(MockRole(id=role_id)) -        self.assertFalse(checks.without_role_check(self.ctx, role_id)) +        self.assertFalse(await checks.has_no_roles_check(self.ctx, role_id)) -    def test_without_role_check_returns_true_without_unwanted_role(self): -        """`without_role_check` returns `True` if `Context.author` does not have unwanted role.""" +    async def test_has_no_roles_check_returns_true_without_unwanted_role(self): +        """`has_no_roles_check` returns `True` if `Context.author` does not have unwanted role."""          role_id = 42          self.ctx.author.roles.append(MockRole(id=role_id)) -        self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) +        self.assertTrue(await checks.has_no_roles_check(self.ctx, role_id + 10))      def test_in_whitelist_check_correct_channel(self):          """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" | 
