diff options
| -rw-r--r-- | bot/converters.py | 51 | ||||
| -rw-r--r-- | bot/exts/info/information.py | 20 | ||||
| -rw-r--r-- | bot/exts/moderation/dm_relay.py | 6 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 30 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/management.py | 4 | ||||
| -rw-r--r-- | bot/exts/utils/reminders.py | 4 | 
6 files changed, 78 insertions, 37 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 0118cc48a..bd4044c7e 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -11,7 +11,7 @@ import dateutil.tz  import discord  from aiohttp import ClientConnectorError  from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter +from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter  from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time  from bot import exts @@ -495,22 +495,51 @@ class HushDurationConverter(Converter):          return duration -class UserMentionOrID(UserConverter): +def _is_an_unambiguous_user_argument(argument: str) -> bool: +    """Check if the provided argument is a user mention, user id, or username (name#discrim).""" +    has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument)) + +    # Check to see if the author passed a username (a discriminator exists) +    argument = argument.removeprefix('@') +    has_username = len(argument) > 5 and argument[-5] == '#' + +    return has_id_or_mention or has_username + + +AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format" +                          " `name#discriminator`.") + + +class UnambiguousUser(UserConverter):      """ -    Converts to a `discord.User`, but only if a mention or userID is provided. +    Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided. -    Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim. -    This is useful in cases where that lookup strategy would lead to ambiguity. +    Unlike the default `UserConverter`, it doesn't allow conversion from a name. +    This is useful in cases where that lookup strategy would lead to too much ambiguity.      """      async def convert(self, ctx: Context, argument: str) -> discord.User: -        """Convert the `arg` to a `discord.User`.""" -        match = self._get_id_match(argument) or RE_USER_MENTION.match(argument) +        """Convert the `argument` to a `discord.User`.""" +        if _is_an_unambiguous_user_argument(argument): +            return await super().convert(ctx, argument) +        else: +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument)) + + +class UnambiguousMember(MemberConverter): +    """ +    Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided. + +    Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname. +    This is useful in cases where that lookup strategy would lead to too much ambiguity. +    """ -        if match is not None: +    async def convert(self, ctx: Context, argument: str) -> discord.Member: +        """Convert the `argument` to a `discord.Member`.""" +        if _is_an_unambiguous_user_argument(argument):              return await super().convert(ctx, argument)          else: -            raise BadArgument(f"`{argument}` is not a User mention or a User ID.") +            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))  class Infraction(Converter): @@ -557,8 +586,10 @@ if t.TYPE_CHECKING:      OffTopicName = str  # noqa: F811      ISODateTime = datetime  # noqa: F811      HushDurationConverter = int  # noqa: F811 -    UserMentionOrID = discord.User  # noqa: F811 +    UnambiguousUser = discord.User  # noqa: F811 +    UnambiguousMember = discord.Member  # noqa: F811      Infraction = t.Optional[dict]  # noqa: F811  Expiry = t.Union[Duration, ISODateTime]  MemberOrUser = t.Union[discord.Member, discord.User] +UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index ae547b1b8..bcf8c10d2 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -460,11 +460,12 @@ class Information(Cog):          # remove trailing whitespace          return out.rstrip() -    @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_PARTNERS_COMMUNITY_ROLES) -    @group(invoke_without_command=True) -    @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_PARTNERS_COMMUNITY_ROLES) -    async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: -        """Shows information about the raw API response.""" +    async def send_raw_content(self, ctx: Context, message: Message, json: bool = False) -> None: +        """ +        Send information about the raw API response for a `discord.Message`. + +        If `json` is True, send the information in a copy-pasteable Python format. +        """          if ctx.author not in message.channel.members:              await ctx.send(":x: You do not have permissions to see the channel this message is in.")              return @@ -500,10 +501,17 @@ class Information(Cog):          for page in paginator.pages:              await ctx.send(page, allowed_mentions=AllowedMentions.none()) +    @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_PARTNERS_COMMUNITY_ROLES) +    @group(invoke_without_command=True) +    @in_whitelist(channels=(constants.Channels.bot_commands,), roles=constants.STAFF_PARTNERS_COMMUNITY_ROLES) +    async def raw(self, ctx: Context, message: Message) -> None: +        """Shows information about the raw API response.""" +        await self.send_raw_content(ctx, message) +      @raw.command()      async def json(self, ctx: Context, message: Message) -> None:          """Shows information about the raw API response in a copy-pasteable Python format.""" -        await ctx.invoke(self.raw, message=message, json=True) +        await self.send_raw_content(ctx, message, json=True)  def setup(bot: Bot) -> None: diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 1d2206e27..0051db82f 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -5,6 +5,7 @@ from discord.ext.commands import Cog, Context, command, has_any_role  from bot.bot import Bot  from bot.constants import Emojis, MODERATION_ROLES +from bot.utils.channel import is_mod_channel  from bot.utils.services import send_to_paste_service  log = logging.getLogger(__name__) @@ -63,8 +64,9 @@ class DMRelay(Cog):          await ctx.send(paste_link)      async def cog_check(self, ctx: Context) -> bool: -        """Only allow moderators to invoke the commands in this cog.""" -        return await has_any_role(*MODERATION_ROLES).predicate(ctx) +        """Only allow moderators to invoke the commands in this cog in mod channels.""" +        return (await has_any_role(*MODERATION_ROLES).predicate(ctx) +                and is_mod_channel(ctx.channel))  def setup(bot: Bot) -> None: diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 2f9083c29..eaba97703 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -10,7 +10,7 @@ from discord.ext.commands import Context, command  from bot import constants  from bot.bot import Bot  from bot.constants import Event -from bot.converters import Duration, Expiry, MemberOrUser +from bot.converters import Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser  from bot.decorators import respect_role_hierarchy  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler @@ -53,7 +53,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent infractions      @command() -    async def warn(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None: +    async def warn(self, ctx: Context, user: UnambiguousMemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Warn a user for the given reason."""          if not isinstance(user, Member):              await ctx.send(":x: The user doesn't appear to be on the server.") @@ -66,7 +66,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user)      @command() -    async def kick(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None: +    async def kick(self, ctx: Context, user: UnambiguousMemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Kick a user for the given reason."""          if not isinstance(user, Member):              await ctx.send(":x: The user doesn't appear to be on the server.") @@ -78,7 +78,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def ban(          self,          ctx: Context, -        user: MemberOrUser, +        user: UnambiguousMemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -94,7 +94,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def purgeban(          self,          ctx: Context, -        user: MemberOrUser, +        user: UnambiguousMemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -110,7 +110,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def voiceban(          self,          ctx: Context, -        user: MemberOrUser, +        user: UnambiguousMemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] @@ -128,7 +128,7 @@ class Infractions(InfractionScheduler, commands.Cog):      @command(aliases=["mute"])      async def tempmute(          self, ctx: Context, -        user: MemberOrUser, +        user: UnambiguousMemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -162,7 +162,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def tempban(          self,          ctx: Context, -        user: MemberOrUser, +        user: UnambiguousMemberOrUser,          duration: Expiry,          *,          reason: t.Optional[str] = None @@ -188,7 +188,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def tempvoiceban(              self,              ctx: Context, -            user: MemberOrUser, +            user: UnambiguousMemberOrUser,              duration: Expiry,              *,              reason: t.Optional[str] @@ -214,7 +214,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent shadow infractions      @command(hidden=True) -    async def note(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None: +    async def note(self, ctx: Context, user: UnambiguousMemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Create a private note for a user with the given reason without notifying the user."""          infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False)          if infraction is None: @@ -223,7 +223,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user)      @command(hidden=True, aliases=['shadowban', 'sban']) -    async def shadow_ban(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None: +    async def shadow_ban(self, ctx: Context, user: UnambiguousMemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Permanently ban a user for the given reason without notifying the user."""          await self.apply_ban(ctx, user, reason, hidden=True) @@ -234,7 +234,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def shadow_tempban(          self,          ctx: Context, -        user: MemberOrUser, +        user: UnambiguousMemberOrUser,          duration: Expiry,          *,          reason: t.Optional[str] = None @@ -260,17 +260,17 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Remove infractions (un- commands)      @command() -    async def unmute(self, ctx: Context, user: MemberOrUser) -> None: +    async def unmute(self, ctx: Context, user: UnambiguousMemberOrUser) -> None:          """Prematurely end the active mute infraction for the user."""          await self.pardon_infraction(ctx, "mute", user)      @command() -    async def unban(self, ctx: Context, user: MemberOrUser) -> None: +    async def unban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None:          """Prematurely end the active ban infraction for the user."""          await self.pardon_infraction(ctx, "ban", user)      @command(aliases=("uvban",)) -    async def unvoiceban(self, ctx: Context, user: MemberOrUser) -> None: +    async def unvoiceban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None:          """Prematurely end the active voice ban infraction for the user."""          await self.pardon_infraction(ctx, "voice_ban", user) diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index 641ad0410..7f27896d7 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -12,7 +12,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UserMentionOrID, allowed_strings +from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser, allowed_strings  from bot.exts.moderation.infraction.infractions import Infractions  from bot.exts.moderation.modlog import ModLog  from bot.pagination import LinePaginator @@ -201,7 +201,7 @@ class ModManagement(commands.Cog):      # region: Search infractions      @infraction_group.group(name="search", aliases=('s',), invoke_without_command=True) -    async def infraction_search_group(self, ctx: Context, query: t.Union[UserMentionOrID, Snowflake, str]) -> None: +    async def infraction_search_group(self, ctx: Context, query: t.Union[UnambiguousUser, Snowflake, str]) -> None:          """Searches for infractions in the database."""          if isinstance(query, int):              await self.search_user(ctx, discord.Object(query)) diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 2bed5157f..41b6cac5c 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -15,7 +15,7 @@ from bot.constants import (      Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES,      Roles, STAFF_PARTNERS_COMMUNITY_ROLES  ) -from bot.converters import Duration, UserMentionOrID +from bot.converters import Duration, UnambiguousUser  from bot.pagination import LinePaginator  from bot.utils.checks import has_any_role_check, has_no_roles_check  from bot.utils.lock import lock_arg @@ -30,7 +30,7 @@ WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5  Mentionable = t.Union[discord.Member, discord.Role] -ReminderMention = t.Union[UserMentionOrID, discord.Role] +ReminderMention = t.Union[UnambiguousUser, discord.Role]  class Reminders(Cog): | 
