diff options
| author | 2021-08-23 21:29:24 +0200 | |
|---|---|---|
| committer | 2021-08-23 21:29:24 +0200 | |
| commit | 89db6237ef32da5addc77b8a138009974feb49b1 (patch) | |
| tree | 1546df7325d4e3a33efe59b067bb80e68831cb89 | |
| parent | remove redundant typehints (diff) | |
| parent | Merge pull request #1772 from D0rs4n/pr/nostafflake (diff) | |
Merge remote-tracking branch 'upstream/main' into converter-typehints
# Conflicts:
#	bot/converters.py
44 files changed, 1297 insertions, 656 deletions
| diff --git a/bot/constants.py b/bot/constants.py index 500803f33..407646b28 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -568,13 +568,15 @@ class Metabase(metaclass=YAMLGetter):      username: Optional[str]      password: Optional[str] -    url: str +    base_url: str      max_session_age: int  class AntiSpam(metaclass=YAMLGetter):      section = 'anti_spam' +    cache_size: int +      clean_offending: bool      ping_everyone: bool diff --git a/bot/converters.py b/bot/converters.py index 566e56220..1c0fd673d 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -4,7 +4,6 @@ import logging  import re  import typing as t  from datetime import datetime -from functools import partial  from ssl import CertificateError  import dateutil.parser @@ -515,29 +514,6 @@ class HushDurationConverter(Converter):          return duration -def proxy_user(user_id: str) -> discord.Object: -    """ -    Create a proxy user object from the given id. - -    Used when a Member or User object cannot be resolved. -    """ -    log.trace(f"Attempting to create a proxy user for the user id {user_id}.") - -    try: -        user_id = int(user_id) -    except ValueError: -        log.debug(f"Failed to create proxy user {user_id}: could not convert to int.") -        raise BadArgument(f"User ID `{user_id}` is invalid - could not convert to an integer.") - -    user = discord.Object(user_id) -    user.mention = user.id -    user.display_name = f"<@{user.id}>" -    user.avatar_url_as = lambda static_format: None -    user.bot = False - -    return user - -  class UserMentionOrID(UserConverter):      """      Converts to a `discord.User`, but only if a mention or userID is provided. @@ -556,64 +532,6 @@ class UserMentionOrID(UserConverter):              raise BadArgument(f"`{argument}` is not a User mention or a User ID.") -class FetchedUser(UserConverter): -    """ -    Converts to a `discord.User` or, if it fails, a `discord.Object`. - -    Unlike the default `UserConverter`, which only does lookups via the global user cache, this -    converter attempts to fetch the user via an API call to Discord when the using the cache is -    unsuccessful. - -    If the fetch also fails and the error doesn't imply the user doesn't exist, then a -    `discord.Object` is returned via the `user_proxy` converter. - -    The lookup strategy is as follows (in order): - -    1. Lookup by ID. -    2. Lookup by mention. -    3. Lookup by name#discrim -    4. Lookup by name -    5. Lookup via API -    6. Create a proxy user with discord.Object -    """ - -    async def convert(self, ctx: Context, arg: str) -> t.Union[discord.User, discord.Object]: -        """Convert the `arg` to a `discord.User` or `discord.Object`.""" -        try: -            return await super().convert(ctx, arg) -        except BadArgument: -            pass - -        try: -            user_id = int(arg) -            log.trace(f"Fetching user {user_id}...") -            return await ctx.bot.fetch_user(user_id) -        except ValueError: -            log.debug(f"Failed to fetch user {arg}: could not convert to int.") -            raise BadArgument(f"The provided argument can't be turned into integer: `{arg}`") -        except discord.HTTPException as e: -            # If the Discord error isn't `Unknown user`, return a proxy instead -            if e.code != 10013: -                log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}") -                return proxy_user(arg) - -            log.debug(f"Failed to fetch user {arg}: user does not exist.") -            raise BadArgument(f"User `{arg}` does not exist") - - -def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int: -    """ -    Extract the snowflake from `arg` using a regex `pattern` and return it as an int. - -    The snowflake is expected to be within the first capture group in `pattern`. -    """ -    match = pattern.match(arg) -    if not match: -        raise BadArgument(f"Mention {arg!r} is invalid.") - -    return int(match.group(1)) - -  class Infraction(Converter):      """      Attempts to convert a given infraction ID into an infraction. @@ -660,9 +578,7 @@ if t.TYPE_CHECKING:      ISODateTime = datetime  # noqa: F811      HushDurationConverter = int  # noqa: F811      UserMentionOrID = discord.User  # noqa: F811 -    FetchedUser = t.Union[discord.User, discord.Object]  # noqa: F811      Infraction = t.Optional[dict]  # noqa: F811  Expiry = t.Union[Duration, ISODateTime] -FetchedMember = t.Union[discord.Member, FetchedUser] -UserMention = partial(_snowflake_from_regex, RE_USER_MENTION) +MemberOrUser = t.Union[discord.Member, discord.User] diff --git a/bot/errors.py b/bot/errors.py index 5785faa44..08396ec3e 100644 --- a/bot/errors.py +++ b/bot/errors.py @@ -1,6 +1,6 @@ -from typing import Hashable, Union +from typing import Hashable -from discord import Member, User +from bot.converters import MemberOrUser  class LockedResourceError(RuntimeError): @@ -30,7 +30,8 @@ class InvalidInfractedUserError(Exception):          `user` -- User or Member which is invalid      """ -    def __init__(self, user: Union[Member, User], reason: str = "User infracted is a bot."): +    def __init__(self, user: MemberOrUser, reason: str = "User infracted is a bot."): +          self.user = user          self.reason = reason diff --git a/bot/exts/events/__init__.py b/bot/exts/events/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/bot/exts/events/__init__.py diff --git a/bot/exts/events/code_jams/__init__.py b/bot/exts/events/code_jams/__init__.py new file mode 100644 index 000000000..16e81e365 --- /dev/null +++ b/bot/exts/events/code_jams/__init__.py @@ -0,0 +1,8 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: +    """Load the CodeJams cog.""" +    from bot.exts.events.code_jams._cog import CodeJams + +    bot.add_cog(CodeJams(bot)) diff --git a/bot/exts/events/code_jams/_channels.py b/bot/exts/events/code_jams/_channels.py new file mode 100644 index 000000000..34ff0ad41 --- /dev/null +++ b/bot/exts/events/code_jams/_channels.py @@ -0,0 +1,113 @@ +import logging +import typing as t + +import discord + +from bot.constants import Categories, Channels, Roles + +log = logging.getLogger(__name__) + +MAX_CHANNELS = 50 +CATEGORY_NAME = "Code Jam" + + +async def _get_category(guild: discord.Guild) -> discord.CategoryChannel: +    """ +    Return a code jam category. + +    If all categories are full or none exist, create a new category. +    """ +    for category in guild.categories: +        if category.name == CATEGORY_NAME and len(category.channels) < MAX_CHANNELS: +            return category + +    return await _create_category(guild) + + +async def _create_category(guild: discord.Guild) -> discord.CategoryChannel: +    """Create a new code jam category and return it.""" +    log.info("Creating a new code jam category.") + +    category_overwrites = { +        guild.default_role: discord.PermissionOverwrite(read_messages=False), +        guild.me: discord.PermissionOverwrite(read_messages=True) +    } + +    category = await guild.create_category_channel( +        CATEGORY_NAME, +        overwrites=category_overwrites, +        reason="It's code jam time!" +    ) + +    await _send_status_update( +        guild, f"Created a new category with the ID {category.id} for this Code Jam's team channels." +    ) + +    return category + + +def _get_overwrites( +        members: list[tuple[discord.Member, bool]], +        guild: discord.Guild, +) -> dict[t.Union[discord.Member, discord.Role], discord.PermissionOverwrite]: +    """Get code jam team channels permission overwrites.""" +    team_channel_overwrites = { +        guild.default_role: discord.PermissionOverwrite(read_messages=False), +        guild.get_role(Roles.code_jam_event_team): discord.PermissionOverwrite(read_messages=True) +    } + +    for member, _ in members: +        team_channel_overwrites[member] = discord.PermissionOverwrite( +            read_messages=True +        ) + +    return team_channel_overwrites + + +async def create_team_channel( +        guild: discord.Guild, +        team_name: str, +        members: list[tuple[discord.Member, bool]], +        team_leaders: discord.Role +) -> None: +    """Create the team's text channel.""" +    await _add_team_leader_roles(members, team_leaders) + +    # Get permission overwrites and category +    team_channel_overwrites = _get_overwrites(members, guild) +    code_jam_category = await _get_category(guild) + +    # Create a text channel for the team +    await code_jam_category.create_text_channel( +        team_name, +        overwrites=team_channel_overwrites, +    ) + + +async def create_team_leader_channel(guild: discord.Guild, team_leaders: discord.Role) -> None: +    """Create the Team Leader Chat channel for the Code Jam team leaders.""" +    category: discord.CategoryChannel = guild.get_channel(Categories.summer_code_jam) + +    team_leaders_chat = await category.create_text_channel( +        name="team-leaders-chat", +        overwrites={ +            guild.default_role: discord.PermissionOverwrite(read_messages=False), +            team_leaders: discord.PermissionOverwrite(read_messages=True) +        } +    ) + +    await _send_status_update(guild, f"Created {team_leaders_chat.mention} in the {category} category.") + + +async def _send_status_update(guild: discord.Guild, message: str) -> None: +    """Inform the events lead with a status update when the command is ran.""" +    channel: discord.TextChannel = guild.get_channel(Channels.code_jam_planning) + +    await channel.send(f"<@&{Roles.events_lead}>\n\n{message}") + + +async def _add_team_leader_roles(members: list[tuple[discord.Member, bool]], team_leaders: discord.Role) -> None: +    """Assign the team leader role to the team leaders.""" +    for member, is_leader in members: +        if is_leader: +            await member.add_roles(team_leaders) diff --git a/bot/exts/events/code_jams/_cog.py b/bot/exts/events/code_jams/_cog.py new file mode 100644 index 000000000..e099f7dfa --- /dev/null +++ b/bot/exts/events/code_jams/_cog.py @@ -0,0 +1,235 @@ +import asyncio +import csv +import logging +import typing as t +from collections import defaultdict + +import discord +from discord import Colour, Embed, Guild, Member +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import Emojis, Roles +from bot.exts.events.code_jams import _channels +from bot.utils.services import send_to_paste_service + +log = logging.getLogger(__name__) + +TEAM_LEADERS_COLOUR = 0x11806a +DELETION_REACTION = "\U0001f4a5" + + +class CodeJams(commands.Cog): +    """Manages the code-jam related parts of our server.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @commands.group(aliases=("cj", "jam")) +    @commands.has_any_role(Roles.admins) +    async def codejam(self, ctx: commands.Context) -> None: +        """A Group of commands for managing Code Jams.""" +        if ctx.invoked_subcommand is None: +            await ctx.send_help(ctx.command) + +    @codejam.command() +    async def create(self, ctx: commands.Context, csv_file: t.Optional[str] = None) -> None: +        """ +        Create code-jam teams from a CSV file or a link to one, specifying the team names, leaders and members. + +        The CSV file must have 3 columns: 'Team Name', 'Team Member Discord ID', and 'Team Leader'. + +        This will create the text channels for the teams, and give the team leaders their roles. +        """ +        async with ctx.typing(): +            if csv_file: +                async with self.bot.http_session.get(csv_file) as response: +                    if response.status != 200: +                        await ctx.send(f"Got a bad response from the URL: {response.status}") +                        return + +                    csv_file = await response.text() + +            elif ctx.message.attachments: +                csv_file = (await ctx.message.attachments[0].read()).decode("utf8") +            else: +                raise commands.BadArgument("You must include either a CSV file or a link to one.") + +            teams = defaultdict(list) +            reader = csv.DictReader(csv_file.splitlines()) + +            for row in reader: +                member = ctx.guild.get_member(int(row["Team Member Discord ID"])) + +                if member is None: +                    log.trace(f"Got an invalid member ID: {row['Team Member Discord ID']}") +                    continue + +                teams[row["Team Name"]].append((member, row["Team Leader"].upper() == "Y")) + +            team_leaders = await ctx.guild.create_role(name="Code Jam Team Leaders", colour=TEAM_LEADERS_COLOUR) + +            for team_name, members in teams.items(): +                await _channels.create_team_channel(ctx.guild, team_name, members, team_leaders) + +            await _channels.create_team_leader_channel(ctx.guild, team_leaders) +            await ctx.send(f"{Emojis.check_mark} Created Code Jam with {len(teams)} teams.") + +    @codejam.command() +    @commands.has_any_role(Roles.admins) +    async def end(self, ctx: commands.Context) -> None: +        """ +        Delete all code jam channels. + +        A confirmation message is displayed with the categories and channels to be deleted.. Pressing the added reaction +        deletes those channels. +        """ +        def predicate_deletion_emoji_reaction(reaction: discord.Reaction, user: discord.User) -> bool: +            """Return True if the reaction :boom: was added by the context message author on this message.""" +            return ( +                reaction.message.id == message.id +                and user.id == ctx.author.id +                and str(reaction) == DELETION_REACTION +            ) + +        # A copy of the list of channels is stored. This is to make sure that we delete precisely the channels displayed +        # in the confirmation message. +        categories = self.jam_categories(ctx.guild) +        category_channels = {category: category.channels.copy() for category in categories} + +        confirmation_message = await self._build_confirmation_message(category_channels) +        message = await ctx.send(confirmation_message) +        await message.add_reaction(DELETION_REACTION) +        try: +            await self.bot.wait_for( +                'reaction_add', +                check=predicate_deletion_emoji_reaction, +                timeout=10 +            ) + +        except asyncio.TimeoutError: +            await message.clear_reaction(DELETION_REACTION) +            await ctx.send("Command timed out.", reference=message) +            return + +        else: +            await message.clear_reaction(DELETION_REACTION) +            for category, channels in category_channels.items(): +                for channel in channels: +                    await channel.delete(reason="Code jam ended.") +                await category.delete(reason="Code jam ended.") + +            await message.add_reaction(Emojis.check_mark) + +    @staticmethod +    async def _build_confirmation_message( +        categories: dict[discord.CategoryChannel, list[discord.abc.GuildChannel]] +    ) -> str: +        """Sends details of the channels to be deleted to the pasting service, and formats the confirmation message.""" +        def channel_repr(channel: discord.abc.GuildChannel) -> str: +            """Formats the channel name and ID and a readable format.""" +            return f"{channel.name} ({channel.id})" + +        def format_category_info(category: discord.CategoryChannel, channels: list[discord.abc.GuildChannel]) -> str: +            """Displays the category and the channels within it in a readable format.""" +            return f"{channel_repr(category)}:\n" + "\n".join("  - " + channel_repr(channel) for channel in channels) + +        deletion_details = "\n\n".join( +            format_category_info(category, channels) for category, channels in categories.items() +        ) + +        url = await send_to_paste_service(deletion_details) +        if url is None: +            url = "**Unable to send deletion details to the pasting service.**" + +        return f"Are you sure you want to delete all code jam channels?\n\nThe channels to be deleted: {url}" + +    @codejam.command() +    @commands.has_any_role(Roles.admins, Roles.code_jam_event_team) +    async def info(self, ctx: commands.Context, member: Member) -> None: +        """ +        Send an info embed about the member with the team they're in. + +        The team is found by searching the permissions of the team channels. +        """ +        channel = self.team_channel(ctx.guild, member) +        if not channel: +            await ctx.send(":x: I can't find the team channel for this member.") +            return + +        embed = Embed( +            title=str(member), +            colour=Colour.blurple() +        ) +        embed.add_field(name="Team", value=self.team_name(channel), inline=True) + +        await ctx.send(embed=embed) + +    @codejam.command() +    @commands.has_any_role(Roles.admins) +    async def move(self, ctx: commands.Context, member: Member, new_team_name: str) -> None: +        """Move participant from one team to another by changing the user's permissions for the relevant channels.""" +        old_team_channel = self.team_channel(ctx.guild, member) +        if not old_team_channel: +            await ctx.send(":x: I can't find the team channel for this member.") +            return + +        if old_team_channel.name == new_team_name or self.team_name(old_team_channel) == new_team_name: +            await ctx.send(f"`{member}` is already in `{new_team_name}`.") +            return + +        new_team_channel = self.team_channel(ctx.guild, new_team_name) +        if not new_team_channel: +            await ctx.send(f":x: I can't find a team channel named `{new_team_name}`.") +            return + +        await old_team_channel.set_permissions(member, overwrite=None, reason=f"Participant moved to {new_team_name}") +        await new_team_channel.set_permissions( +            member, +            overwrite=discord.PermissionOverwrite(read_messages=True), +            reason=f"Participant moved from {old_team_channel.name}" +        ) + +        await ctx.send( +            f"Participant moved from `{self.team_name(old_team_channel)}` to `{self.team_name(new_team_channel)}`." +        ) + +    @codejam.command() +    @commands.has_any_role(Roles.admins) +    async def remove(self, ctx: commands.Context, member: Member) -> None: +        """Remove the participant from their team. Does not remove the participants or leader roles.""" +        channel = self.team_channel(ctx.guild, member) +        if not channel: +            await ctx.send(":x: I can't find the team channel for this member.") +            return + +        await channel.set_permissions( +            member, +            overwrite=None, +            reason=f"Participant removed from the team  {self.team_name(channel)}." +        ) +        await ctx.send(f"Removed the participant from `{self.team_name(channel)}`.") + +    @staticmethod +    def jam_categories(guild: Guild) -> list[discord.CategoryChannel]: +        """Get all the code jam team categories.""" +        return [category for category in guild.categories if category.name == _channels.CATEGORY_NAME] + +    @staticmethod +    def team_channel(guild: Guild, criterion: t.Union[str, Member]) -> t.Optional[discord.TextChannel]: +        """Get a team channel through either a participant or the team name.""" +        for category in CodeJams.jam_categories(guild): +            for channel in category.channels: +                if isinstance(channel, discord.TextChannel): +                    if ( +                        # If it's a string. +                        criterion == channel.name or criterion == CodeJams.team_name(channel) +                        # If it's a member. +                        or criterion in channel.overwrites +                    ): +                        return channel + +    @staticmethod +    def team_name(channel: discord.TextChannel) -> str: +        """Retrieves the team name from the given channel.""" +        return channel.name.replace("-", " ").title() diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py index 4c4836c88..0eedeb0fb 100644 --- a/bot/exts/filters/antimalware.py +++ b/bot/exts/filters/antimalware.py @@ -7,7 +7,7 @@ from discord.ext.commands import Cog  from bot.bot import Bot  from bot.constants import Channels, Filter, URLs -from bot.exts.utils.jams import CATEGORY_NAME as JAM_CATEGORY_NAME +from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME  log = logging.getLogger(__name__) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 3f891b2c6..987060779 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -1,8 +1,10 @@  import asyncio  import logging +from collections import defaultdict  from collections.abc import Mapping  from dataclasses import dataclass, field  from datetime import datetime, timedelta +from itertools import takewhile  from operator import attrgetter, itemgetter  from typing import Dict, Iterable, List, Set @@ -17,9 +19,10 @@ from bot.constants import (      Guild as GuildConfig, Icons,  )  from bot.converters import Duration +from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME  from bot.exts.moderation.modlog import ModLog -from bot.exts.utils.jams import CATEGORY_NAME as JAM_CATEGORY_NAME  from bot.utils import lock, scheduling +from bot.utils.message_cache import MessageCache  from bot.utils.messages import format_user, send_attachments @@ -44,19 +47,18 @@ RULE_FUNCTION_MAPPING = {  class DeletionContext:      """Represents a Deletion Context for a single spam event.""" -    channel: TextChannel -    members: Dict[int, Member] = field(default_factory=dict) +    members: frozenset[Member] +    triggered_in: TextChannel +    channels: set[TextChannel] = field(default_factory=set)      rules: Set[str] = field(default_factory=set)      messages: Dict[int, Message] = field(default_factory=dict)      attachments: List[List[str]] = field(default_factory=list) -    async def add(self, rule_name: str, members: Iterable[Member], messages: Iterable[Message]) -> None: +    async def add(self, rule_name: str, channels: Iterable[TextChannel], messages: Iterable[Message]) -> None:          """Adds new rule violation events to the deletion context."""          self.rules.add(rule_name) -        for member in members: -            if member.id not in self.members: -                self.members[member.id] = member +        self.channels.update(channels)          for message in messages:              if message.id not in self.messages: @@ -69,11 +71,14 @@ class DeletionContext:      async def upload_messages(self, actor_id: int, modlog: ModLog) -> None:          """Method that takes care of uploading the queue and posting modlog alert.""" -        triggered_by_users = ", ".join(format_user(m) for m in self.members.values()) +        triggered_by_users = ", ".join(format_user(m) for m in self.members) +        triggered_in_channel = f"**Triggered in:** {self.triggered_in.mention}\n" if len(self.channels) > 1 else "" +        channels_description = ", ".join(channel.mention for channel in self.channels)          mod_alert_message = (              f"**Triggered by:** {triggered_by_users}\n" -            f"**Channel:** {self.channel.mention}\n" +            f"{triggered_in_channel}" +            f"**Channels:** {channels_description}\n"              f"**Rules:** {', '.join(rule for rule in self.rules)}\n"          ) @@ -116,6 +121,14 @@ class AntiSpam(Cog):          self.message_deletion_queue = dict() +        # Fetch the rule configuration with the highest rule interval. +        max_interval_config = max( +            AntiSpamConfig.rules.values(), +            key=itemgetter('interval') +        ) +        self.max_interval = max_interval_config['interval'] +        self.cache = MessageCache(AntiSpamConfig.cache_size, newest_first=True) +          self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")      @property @@ -155,19 +168,10 @@ class AntiSpam(Cog):          ):              return -        # Fetch the rule configuration with the highest rule interval. -        max_interval_config = max( -            AntiSpamConfig.rules.values(), -            key=itemgetter('interval') -        ) -        max_interval = max_interval_config['interval'] +        self.cache.append(message) -        # Store history messages since `interval` seconds ago in a list to prevent unnecessary API calls. -        earliest_relevant_at = datetime.utcnow() - timedelta(seconds=max_interval) -        relevant_messages = [ -            msg async for msg in message.channel.history(after=earliest_relevant_at, oldest_first=False) -            if not msg.author.bot -        ] +        earliest_relevant_at = datetime.utcnow() - timedelta(seconds=self.max_interval) +        relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.cache))          for rule_name in AntiSpamConfig.rules:              rule_config = AntiSpamConfig.rules[rule_name] @@ -175,9 +179,10 @@ class AntiSpam(Cog):              # Create a list of messages that were sent in the interval that the rule cares about.              latest_interesting_stamp = datetime.utcnow() - timedelta(seconds=rule_config['interval']) -            messages_for_rule = [ -                msg for msg in relevant_messages if msg.created_at > latest_interesting_stamp -            ] +            messages_for_rule = list( +                takewhile(lambda msg: msg.created_at > latest_interesting_stamp, relevant_messages) +            ) +              result = await rule_function(message, messages_for_rule, rule_config)              # If the rule returns `None`, that means the message didn't violate it. @@ -190,19 +195,19 @@ class AntiSpam(Cog):                  full_reason = f"`{rule_name}` rule: {reason}"                  # If there's no spam event going on for this channel, start a new Message Deletion Context -                channel = message.channel -                if channel.id not in self.message_deletion_queue: -                    log.trace(f"Creating queue for channel `{channel.id}`") -                    self.message_deletion_queue[message.channel.id] = DeletionContext(channel) +                authors_set = frozenset(members) +                if authors_set not in self.message_deletion_queue: +                    log.trace(f"Creating queue for members `{authors_set}`") +                    self.message_deletion_queue[authors_set] = DeletionContext(authors_set, message.channel)                      scheduling.create_task( -                        self._process_deletion_context(message.channel.id), -                        name=f"AntiSpam._process_deletion_context({message.channel.id})" +                        self._process_deletion_context(authors_set), +                        name=f"AntiSpam._process_deletion_context({authors_set})"                      )                  # Add the relevant of this trigger to the Deletion Context -                await self.message_deletion_queue[message.channel.id].add( +                await self.message_deletion_queue[authors_set].add(                      rule_name=rule_name, -                    members=members, +                    channels=set(message.channel for message in messages_for_rule),                      messages=relevant_messages                  ) @@ -212,7 +217,7 @@ class AntiSpam(Cog):                          name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"                      ) -                await self.maybe_delete_messages(channel, relevant_messages) +                await self.maybe_delete_messages(messages_for_rule)                  break      @lock.lock_arg("antispam.punish", "member", attrgetter("id")) @@ -234,14 +239,18 @@ class AntiSpam(Cog):                  reason=reason              ) -    async def maybe_delete_messages(self, channel: TextChannel, messages: List[Message]) -> None: +    async def maybe_delete_messages(self, messages: List[Message]) -> None:          """Cleans the messages if cleaning is configured."""          if AntiSpamConfig.clean_offending:              # If we have more than one message, we can use bulk delete.              if len(messages) > 1:                  message_ids = [message.id for message in messages]                  self.mod_log.ignore(Event.message_delete, *message_ids) -                await channel.delete_messages(messages) +                channel_messages = defaultdict(list) +                for message in messages: +                    channel_messages[message.channel].append(message) +                for channel, messages in channel_messages.items(): +                    await channel.delete_messages(messages)              # Otherwise, the bulk delete endpoint will throw up.              # Delete the message directly instead. @@ -252,7 +261,7 @@ class AntiSpam(Cog):                  except NotFound:                      log.info(f"Tried to delete message `{messages[0].id}`, but message could not be found.") -    async def _process_deletion_context(self, context_id: int) -> None: +    async def _process_deletion_context(self, context_id: frozenset) -> None:          """Processes the Deletion Context queue."""          log.trace("Sleeping before processing message deletion queue.")          await asyncio.sleep(10) @@ -264,6 +273,11 @@ class AntiSpam(Cog):          deletion_context = self.message_deletion_queue.pop(context_id)          await deletion_context.upload_messages(self.bot.user.id, self.mod_log) +    @Cog.listener() +    async def on_message_edit(self, before: Message, after: Message) -> None: +        """Updates the message in the cache, if it's cached.""" +        self.cache.update(after) +  def validate_config(rules_: Mapping = AntiSpamConfig.rules) -> Dict[str, str]:      """Validates the antispam configs.""" diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 16aaf11cf..10cc7885d 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -19,8 +19,8 @@ from bot.constants import (      Channels, Colours, Filter,      Guild, Icons, URLs  ) +from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME  from bot.exts.moderation.modlog import ModLog -from bot.exts.utils.jams import CATEGORY_NAME as JAM_CATEGORY_NAME  from bot.utils.messages import format_user  from bot.utils.regex import INVITE_RE  from bot.utils.scheduling import Scheduler diff --git a/bot/exts/filters/webhook_remover.py b/bot/exts/filters/webhook_remover.py index f11fc8912..25e267426 100644 --- a/bot/exts/filters/webhook_remover.py +++ b/bot/exts/filters/webhook_remover.py @@ -9,12 +9,15 @@ from bot.constants import Channels, Colours, Event, Icons  from bot.exts.moderation.modlog import ModLog  from bot.utils.messages import format_user -WEBHOOK_URL_RE = re.compile(r"((?:https?://)?discord(?:app)?\.com/api/webhooks/\d+/)\S+/?", re.IGNORECASE) +WEBHOOK_URL_RE = re.compile( +    r"((?:https?:\/\/)?(?:ptb\.|canary\.)?discord(?:app)?\.com\/api\/webhooks\/\d+\/)\S+\/?", +    re.IGNORECASE +)  ALERT_MESSAGE_TEMPLATE = (      "{user}, looks like you posted a Discord webhook URL. Therefore, your " -    "message has been removed. Your webhook may have been **compromised** so " -    "please re-create the webhook **immediately**. If you believe this was a " +    "message has been removed, and your webhook has been deleted. " +    "You can re-create it if you wish to. If you believe this was a "      "mistake, please let us know."  ) @@ -32,7 +35,7 @@ class WebhookRemover(Cog):          """Get current instance of `ModLog`."""          return self.bot.get_cog("ModLog") -    async def delete_and_respond(self, msg: Message, redacted_url: str) -> None: +    async def delete_and_respond(self, msg: Message, redacted_url: str, *, webhook_deleted: bool) -> None:          """Delete `msg` and send a warning that it contained the Discord webhook `redacted_url`."""          # Don't log this, due internal delete, not by user. Will make different entry.          self.mod_log.ignore(Event.message_delete, msg.id) @@ -44,9 +47,12 @@ class WebhookRemover(Cog):              return          await msg.channel.send(ALERT_MESSAGE_TEMPLATE.format(user=msg.author.mention)) - +        if webhook_deleted: +            delete_state = "The webhook was successfully deleted." +        else: +            delete_state = "There was an error when deleting the webhook, it might have already been removed."          message = ( -            f"{format_user(msg.author)} posted a Discord webhook URL to {msg.channel.mention}. " +            f"{format_user(msg.author)} posted a Discord webhook URL to {msg.channel.mention}. {delete_state} "              f"Webhook URL was `{redacted_url}`"          )          log.debug(message) @@ -72,7 +78,10 @@ class WebhookRemover(Cog):          matches = WEBHOOK_URL_RE.search(msg.content)          if matches: -            await self.delete_and_respond(msg, matches[1] + "xxx") +            async with self.bot.http_session.delete(matches[0]) as resp: +                # The Discord API Returns a 204 NO CONTENT response on success. +                deleted_successfully = resp.status == 204 +            await self.delete_and_respond(msg, matches[1] + "xxx", webhook_deleted=deleted_successfully)      @Cog.listener()      async def on_message_edit(self, before: Message, after: Message) -> None: diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index c78b9c141..7f7e4585c 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -3,11 +3,12 @@ import logging  from typing import Union  import discord -from discord import Color, Embed, Member, Message, RawReactionActionEvent, TextChannel, User, errors +from discord import Color, Embed, Message, RawReactionActionEvent, TextChannel, errors  from discord.ext.commands import Cog, Context, command  from bot import constants  from bot.bot import Bot +from bot.converters import MemberOrUser  from bot.utils.checks import has_any_role  from bot.utils.messages import count_unique_users_reaction, send_attachments  from bot.utils.webhooks import send_webhook @@ -36,7 +37,7 @@ class DuckPond(Cog):              log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`")      @staticmethod -    def is_staff(member: Union[User, Member]) -> bool: +    def is_staff(member: MemberOrUser) -> bool:          """Check if a specific member or user is staff."""          if hasattr(member, "roles"):              for role in member.roles: @@ -171,8 +172,14 @@ class DuckPond(Cog):          if not self.is_helper_viewable(channel):              return -        message = await channel.fetch_message(payload.message_id) +        try: +            message = await channel.fetch_message(payload.message_id) +        except discord.NotFound: +            return  # Message was deleted. +          member = discord.utils.get(message.guild.members, id=payload.user_id) +        if not member: +            return  # Member left or wasn't in the cache.          # Was the message sent by a human staff member?          if not self.is_staff(message.author) or message.author.bot: diff --git a/bot/exts/help_channels/_cog.py b/bot/exts/help_channels/_cog.py index 35658d117..cfc9cf477 100644 --- a/bot/exts/help_channels/_cog.py +++ b/bot/exts/help_channels/_cog.py @@ -267,6 +267,8 @@ class HelpChannels(commands.Cog):              for channel in channels[:abs(missing)]:                  await self.unclaim_channel(channel, closed_on=_channel.ClosingReason.CLEANUP) +        self.available_help_channels = set(_channel.get_category_channels(self.available_category)) +          # Getting channels that need to be included in the dynamic message.          await self.update_available_help_channels()          log.trace("Dynamic available help message updated.") @@ -387,7 +389,12 @@ class HelpChannels(commands.Cog):          )          log.trace(f"Sending dormant message for #{channel} ({channel.id}).") -        embed = discord.Embed(description=_message.DORMANT_MSG) +        embed = discord.Embed( +            description=_message.DORMANT_MSG.format( +                dormant=self.dormant_category.name, +                available=self.available_category.name, +            ) +        )          await channel.send(embed=embed)          log.trace(f"Pushing #{channel} ({channel.id}) into the channel queue.") @@ -511,11 +518,6 @@ class HelpChannels(commands.Cog):      async def update_available_help_channels(self) -> None:          """Updates the dynamic message within #how-to-get-help for available help channels.""" -        if not self.available_help_channels: -            self.available_help_channels = set( -                c for c in self.available_category.channels if not _channel.is_excluded_channel(c) -            ) -          available_channels = AVAILABLE_HELP_CHANNELS.format(              available=", ".join(                  c.mention for c in sorted(self.available_help_channels, key=attrgetter("position")) diff --git a/bot/exts/help_channels/_message.py b/bot/exts/help_channels/_message.py index befacd263..077b20b47 100644 --- a/bot/exts/help_channels/_message.py +++ b/bot/exts/help_channels/_message.py @@ -30,12 +30,12 @@ AVAILABLE_TITLE = "Available help channel"  AVAILABLE_FOOTER = "Closes after a period of inactivity, or when you send !close."  DORMANT_MSG = f""" -This help channel has been marked as **dormant**, and has been moved into the **Help: Dormant** \ +This help channel has been marked as **dormant**, and has been moved into the **{{dormant}}** \  category at the bottom of the channel list. It is no longer possible to send messages in this \  channel until it becomes available again.  If your question wasn't answered yet, you can claim a new help channel from the \ -**Help: Available** category by simply asking your question again. Consider rephrasing the \ +**{{available}}** category by simply asking your question again. Consider rephrasing the \  question to maximize your chance of getting a good answer. If you're not sure how, have a look \  through our guide for **[asking a good question]({ASKING_GUIDE_URL})**.  """ diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py index 24a9ae28a..4a90a0668 100644 --- a/bot/exts/info/code_snippets.py +++ b/bot/exts/info/code_snippets.py @@ -4,8 +4,8 @@ import textwrap  from typing import Any  from urllib.parse import quote_plus +import discord  from aiohttp import ClientResponseError -from discord import Message  from discord.ext.commands import Cog  from bot.bot import Bot @@ -45,6 +45,17 @@ class CodeSnippets(Cog):      Matches each message against a regex and prints the contents of all matched snippets.      """ +    def __init__(self, bot: Bot): +        """Initializes the cog's bot.""" +        self.bot = bot + +        self.pattern_handlers = [ +            (GITHUB_RE, self._fetch_github_snippet), +            (GITHUB_GIST_RE, self._fetch_github_gist_snippet), +            (GITLAB_RE, self._fetch_gitlab_snippet), +            (BITBUCKET_RE, self._fetch_bitbucket_snippet) +        ] +      async def _fetch_response(self, url: str, response_format: str, **kwargs) -> Any:          """Makes http requests using aiohttp."""          async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response: @@ -208,56 +219,56 @@ class CodeSnippets(Cog):          # Returns an empty codeblock if the snippet is empty          return f'{ret}``` ```' -    def __init__(self, bot: Bot): -        """Initializes the cog's bot.""" -        self.bot = bot +    async def _parse_snippets(self, content: str) -> str: +        """Parse message content and return a string with a code block for each URL found.""" +        all_snippets = [] + +        for pattern, handler in self.pattern_handlers: +            for match in pattern.finditer(content): +                try: +                    snippet = await handler(**match.groupdict()) +                    all_snippets.append((match.start(), snippet)) +                except ClientResponseError as error: +                    error_message = error.message  # noqa: B306 +                    log.log( +                        logging.DEBUG if error.status == 404 else logging.ERROR, +                        f'Failed to fetch code snippet from {match[0]!r}: {error.status} ' +                        f'{error_message} for GET {error.request_info.real_url.human_repr()}' +                    ) -        self.pattern_handlers = [ -            (GITHUB_RE, self._fetch_github_snippet), -            (GITHUB_GIST_RE, self._fetch_github_gist_snippet), -            (GITLAB_RE, self._fetch_gitlab_snippet), -            (BITBUCKET_RE, self._fetch_bitbucket_snippet) -        ] +        # Sorts the list of snippets by their match index and joins them into a single message +        return '\n'.join(map(lambda x: x[1], sorted(all_snippets)))      @Cog.listener() -    async def on_message(self, message: Message) -> None: +    async def on_message(self, message: discord.Message) -> None:          """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" -        if not message.author.bot: -            all_snippets = [] - -            for pattern, handler in self.pattern_handlers: -                for match in pattern.finditer(message.content): -                    try: -                        snippet = await handler(**match.groupdict()) -                        all_snippets.append((match.start(), snippet)) -                    except ClientResponseError as error: -                        error_message = error.message  # noqa: B306 -                        log.log( -                            logging.DEBUG if error.status == 404 else logging.ERROR, -                            f'Failed to fetch code snippet from {match[0]!r}: {error.status} ' -                            f'{error_message} for GET {error.request_info.real_url.human_repr()}' -                        ) - -            # Sorts the list of snippets by their match index and joins them into a single message -            message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets))) - -            if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: +        if message.author.bot: +            return + +        message_to_send = await self._parse_snippets(message.content) +        destination = message.channel + +        if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: +            try:                  await message.edit(suppress=True) -                if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands: -                    # Redirects to #bot-commands if the snippet contents are too long -                    await self.bot.wait_until_guild_available() -                    await message.channel.send(('The snippet you tried to send was too long. Please ' -                                                f'see <#{Channels.bot_commands}> for the full snippet.')) -                    bot_commands_channel = self.bot.get_channel(Channels.bot_commands) -                    await wait_for_deletion( -                        await bot_commands_channel.send(message_to_send), -                        (message.author.id,) -                    ) -                else: -                    await wait_for_deletion( -                        await message.channel.send(message_to_send), -                        (message.author.id,) -                    ) +            except discord.NotFound: +                # Don't send snippets if the original message was deleted. +                return + +            if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands: +                # Redirects to #bot-commands if the snippet contents are too long +                await self.bot.wait_until_guild_available() +                destination = self.bot.get_channel(Channels.bot_commands) + +                await message.channel.send( +                    'The snippet you tried to send was too long. ' +                    f'Please see {destination.mention} for the full snippet.' +                ) + +            await wait_for_deletion( +                await destination.send(message_to_send), +                (message.author.id,) +            )  def setup(bot: Bot) -> None: diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index 9094d9d15..9a0705d2b 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -177,10 +177,13 @@ class CodeBlockCog(Cog, name="Code Block"):          if not bot_message:              return -        if not instructions: -            log.info("User's incorrect code block has been fixed. Removing instructions message.") -            await bot_message.delete() -            del self.codeblock_message_ids[payload.message_id] -        else: -            log.info("Message edited but still has invalid code blocks; editing the instructions.") -            await bot_message.edit(embed=self.create_embed(instructions)) +        try: +            if not instructions: +                log.info("User's incorrect code block was fixed. Removing instructions message.") +                await bot_message.delete() +                del self.codeblock_message_ids[payload.message_id] +            else: +                log.info("Message edited but still has invalid code blocks; editing instructions.") +                await bot_message.edit(embed=self.create_embed(instructions)) +        except discord.NotFound: +            log.debug("Could not find instructions message; it was probably deleted.") diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 167731e64..8bef6a8cd 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -8,11 +8,12 @@ from typing import Any, DefaultDict, Mapping, Optional, Tuple, Union  import rapidfuzz  from discord import AllowedMentions, Colour, Embed, Guild, Message, Role  from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role +from discord.utils import escape_markdown  from bot import constants  from bot.api import ResponseCodeError  from bot.bot import Bot -from bot.converters import FetchedMember +from bot.converters import MemberOrUser  from bot.decorators import in_whitelist  from bot.errors import NonExistentRoleError  from bot.pagination import LinePaginator @@ -186,21 +187,21 @@ class Information(Cog):          online_presences = py_invite.approximate_presence_count          offline_presences = py_invite.approximate_member_count - online_presences          member_status = ( -            f"{constants.Emojis.status_online} {online_presences} " -            f"{constants.Emojis.status_offline} {offline_presences}" +            f"{constants.Emojis.status_online} {online_presences:,} " +            f"{constants.Emojis.status_offline} {offline_presences:,}"          ) -        embed.description = textwrap.dedent(f""" -            Created: {created} -            Voice region: {region}\ -            {features} -            Roles: {num_roles} -            Member status: {member_status} -        """) +        embed.description = ( +            f"Created: {created}" +            f"\nVoice region: {region}" +            f"{features}" +            f"\nRoles: {num_roles}" +            f"\nMember status: {member_status}" +        )          embed.set_thumbnail(url=ctx.guild.icon_url)          # Members -        total_members = ctx.guild.member_count +        total_members = f"{ctx.guild.member_count:,}"          member_counts = self.get_member_counts(ctx.guild)          member_info = "\n".join(f"{role}: {count}" for role, count in member_counts.items())          embed.add_field(name=f"Members: {total_members}", value=member_info) @@ -220,7 +221,7 @@ class Information(Cog):          await ctx.send(embed=embed)      @command(name="user", aliases=["user_info", "member", "member_info", "u"]) -    async def user_info(self, ctx: Context, user: FetchedMember = None) -> None: +    async def user_info(self, ctx: Context, user: MemberOrUser = None) -> None:          """Returns info about a user."""          if user is None:              user = ctx.author @@ -235,7 +236,7 @@ class Information(Cog):              embed = await self.create_user_embed(ctx, user)              await ctx.send(embed=embed) -    async def create_user_embed(self, ctx: Context, user: FetchedMember) -> Embed: +    async def create_user_embed(self, ctx: Context, user: MemberOrUser) -> Embed:          """Creates an embed containing information on the `user`."""          on_server = bool(ctx.guild.get_member(user.id)) @@ -244,6 +245,7 @@ class Information(Cog):          name = str(user)          if on_server and user.nick:              name = f"{user.nick} ({name})" +        name = escape_markdown(name)          if user.public_flags.verified_bot:              name += f" {constants.Emojis.verified_bot}" @@ -257,7 +259,11 @@ class Information(Cog):                  badges.append(emoji)          if on_server: -            joined = discord_timestamp(user.joined_at, TimestampFormats.RELATIVE) +            if user.joined_at: +                joined = discord_timestamp(user.joined_at, TimestampFormats.RELATIVE) +            else: +                joined = "Unable to get join date" +              # The 0 is for excluding the default @everyone role,              # and the -1 is for reversing the order of the roles to highest to lowest in hierarchy.              roles = ", ".join(role.mention for role in user.roles[:0:-1]) @@ -307,7 +313,7 @@ class Information(Cog):          return embed -    async def basic_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]: +    async def basic_user_infraction_counts(self, user: MemberOrUser) -> Tuple[str, str]:          """Gets the total and active infraction counts for the given `member`."""          infractions = await self.bot.api_client.get(              'bot/infractions', @@ -324,7 +330,7 @@ class Information(Cog):          return "Infractions", infraction_output -    async def expanded_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]: +    async def expanded_user_infraction_counts(self, user: MemberOrUser) -> Tuple[str, str]:          """          Gets expanded infraction counts for the given `member`. @@ -365,7 +371,7 @@ class Information(Cog):          return "Infractions", "\n".join(infraction_output) -    async def user_nomination_counts(self, user: FetchedMember) -> Tuple[str, str]: +    async def user_nomination_counts(self, user: MemberOrUser) -> Tuple[str, str]:          """Gets the active and historical nomination counts for the given `member`."""          nominations = await self.bot.api_client.get(              'bot/nominations', @@ -390,7 +396,7 @@ class Information(Cog):          return "Nominations", "\n".join(output) -    async def user_messages(self, user: FetchedMember) -> Tuple[Union[bool, str], Tuple[str, str]]: +    async def user_messages(self, user: MemberOrUser) -> Tuple[Union[bool, str], Tuple[str, str]]:          """          Gets the amount of messages for `member`. diff --git a/bot/exts/info/pep.py b/bot/exts/info/pep.py index 8ac96bbdb..b11b34db0 100644 --- a/bot/exts/info/pep.py +++ b/bot/exts/info/pep.py @@ -9,7 +9,7 @@ from discord.ext.commands import Cog, Context, command  from bot.bot import Bot  from bot.constants import Keys -from bot.utils.cache import AsyncCache +from bot.utils.caching import AsyncCache  log = logging.getLogger(__name__) diff --git a/bot/exts/info/python_news.py b/bot/exts/info/python_news.py index a7837c93a..63eb4ac17 100644 --- a/bot/exts/info/python_news.py +++ b/bot/exts/info/python_news.py @@ -1,4 +1,5 @@  import logging +import re  import typing as t  from datetime import date, datetime @@ -72,6 +73,11 @@ class PythonNews(Cog):              if mail["name"].split("@")[0] in constants.PythonNews.mail_lists:                  self.webhook_names[mail["name"].split("@")[0]] = mail["display_name"] +    @staticmethod +    def escape_markdown(content: str) -> str: +        """Escape the markdown underlines and spoilers.""" +        return re.sub(r"[_|]", lambda match: "\\" + match[0], content) +      async def post_pep_news(self) -> None:          """Fetch new PEPs and when they don't have announcement in #python-news, create it."""          # Wait until everything is ready and http_session available @@ -103,7 +109,7 @@ class PythonNews(Cog):              # Build an embed and send a webhook              embed = discord.Embed(                  title=new["title"], -                description=new["summary"], +                description=self.escape_markdown(new["summary"]),                  timestamp=new_datetime,                  url=new["link"],                  colour=constants.Colours.soft_green @@ -167,7 +173,7 @@ class PythonNews(Cog):                  ):                      continue -                content = email_information["content"] +                content = self.escape_markdown(email_information["content"])                  link = THREAD_URL.format(id=thread["href"].split("/")[-2], list=maillist)                  # Build an embed and send a message to the webhook diff --git a/bot/exts/info/site.py b/bot/exts/info/site.py index fb5b99086..28eb558a6 100644 --- a/bot/exts/info/site.py +++ b/bot/exts/info/site.py @@ -9,7 +9,7 @@ from bot.pagination import LinePaginator  log = logging.getLogger(__name__) -PAGES_URL = f"{URLs.site_schema}{URLs.site}/pages" +BASE_URL = f"{URLs.site_schema}{URLs.site}"  class Site(Cog): @@ -43,7 +43,7 @@ class Site(Cog):      @site_group.command(name="resources", root_aliases=("resources", "resource"))      async def site_resources(self, ctx: Context) -> None:          """Info about the site's Resources page.""" -        learning_url = f"{PAGES_URL}/resources" +        learning_url = f"{BASE_URL}/resources"          embed = Embed(title="Resources")          embed.set_footer(text=f"{learning_url}") @@ -59,7 +59,7 @@ class Site(Cog):      @site_group.command(name="tools", root_aliases=("tools",))      async def site_tools(self, ctx: Context) -> None:          """Info about the site's Tools page.""" -        tools_url = f"{PAGES_URL}/resources/tools" +        tools_url = f"{BASE_URL}/resources/tools"          embed = Embed(title="Tools")          embed.set_footer(text=f"{tools_url}") @@ -74,7 +74,7 @@ class Site(Cog):      @site_group.command(name="help")      async def site_help(self, ctx: Context) -> None:          """Info about the site's Getting Help page.""" -        url = f"{PAGES_URL}/resources/guides/asking-good-questions" +        url = f"{BASE_URL}/pages/guides/pydis-guides/asking-good-questions/"          embed = Embed(title="Asking Good Questions")          embed.set_footer(text=url) @@ -90,7 +90,7 @@ class Site(Cog):      @site_group.command(name="faq", root_aliases=("faq",))      async def site_faq(self, ctx: Context) -> None:          """Info about the site's FAQ page.""" -        url = f"{PAGES_URL}/frequently-asked-questions" +        url = f"{BASE_URL}/pages/frequently-asked-questions"          embed = Embed(title="FAQ")          embed.set_footer(text=url) @@ -107,13 +107,13 @@ class Site(Cog):      @site_group.command(name="rules", aliases=("r", "rule"), root_aliases=("rules", "rule"))      async def site_rules(self, ctx: Context, rules: Greedy[int]) -> None:          """Provides a link to all rules or, if specified, displays specific rule(s).""" -        rules_embed = Embed(title='Rules', color=Colour.blurple(), url=f'{PAGES_URL}/rules') +        rules_embed = Embed(title='Rules', color=Colour.blurple(), url=f'{BASE_URL}/pages/rules')          if not rules:              # Rules were not submitted. Return the default description.              rules_embed.description = (                  "The rules and guidelines that apply to this community can be found on" -                f" our [rules page]({PAGES_URL}/rules). We expect" +                f" our [rules page]({BASE_URL}/pages/rules). We expect"                  " all members of the community to have read and understood these."              ) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index 0e479d33f..561e0251e 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -143,7 +143,14 @@ async def add_signals(incident: discord.Message) -> None:              log.trace(f"Skipping emoji as it's already been placed: {signal_emoji}")          else:              log.trace(f"Adding reaction: {signal_emoji}") -            await incident.add_reaction(signal_emoji.value) +            try: +                await incident.add_reaction(signal_emoji.value) +            except discord.NotFound as e: +                if e.code != 10008: +                    raise + +                log.trace(f"Couldn't react with signal because message {incident.id} was deleted; skipping incident") +                return  class Incidents(Cog): @@ -288,14 +295,20 @@ class Incidents(Cog):          members_roles: t.Set[int] = {role.id for role in member.roles}          if not members_roles & ALLOWED_ROLES:  # Intersection is truthy on at least 1 common element              log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals") -            await incident.remove_reaction(reaction, member) +            try: +                await incident.remove_reaction(reaction, member) +            except discord.NotFound: +                log.trace("Couldn't remove reaction because the reaction or its message was deleted")              return          try:              signal = Signal(reaction)          except ValueError:              log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") -            await incident.remove_reaction(reaction, member) +            try: +                await incident.remove_reaction(reaction, member) +            except discord.NotFound: +                log.trace("Couldn't remove reaction because the reaction or its message was deleted")              return          log.trace(f"Received signal: {signal}") @@ -313,7 +326,10 @@ class Incidents(Cog):          confirmation_task = self.make_confirmation_task(incident, timeout)          log.trace("Deleting original message") -        await incident.delete() +        try: +            await incident.delete() +        except discord.NotFound: +            log.trace("Couldn't delete message because it was already deleted")          log.trace(f"Awaiting deletion confirmation: {timeout=} seconds")          try: diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 8286d3635..6ba4e74e9 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -13,8 +13,8 @@ from bot import constants  from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Colours +from bot.converters import MemberOrUser  from bot.exts.moderation.infraction import _utils -from bot.exts.moderation.infraction._utils import UserSnowflake  from bot.exts.moderation.modlog import ModLog  from bot.utils import messages, scheduling, time  from bot.utils.channel import is_mod_channel @@ -115,7 +115,7 @@ class InfractionScheduler:          self,          ctx: Context,          infraction: _utils.Infraction, -        user: UserSnowflake, +        user: MemberOrUser,          action_coro: t.Optional[t.Awaitable] = None,          user_reason: t.Optional[str] = None,          additional_info: str = "", @@ -165,17 +165,10 @@ class InfractionScheduler:              dm_result = f"{constants.Emojis.failmail} "              dm_log_text = "\nDM: **Failed**" -            # Sometimes user is a discord.Object; make it a proper user. -            try: -                if not isinstance(user, (discord.Member, discord.User)): -                    user = await self.bot.fetch_user(user.id) -            except discord.HTTPException as e: -                log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})") -            else: -                # Accordingly display whether the user was successfully notified via DM. -                if await _utils.notify_infraction(user, infr_type.replace("_", " ").title(), expiry, user_reason, icon): -                    dm_result = ":incoming_envelope: " -                    dm_log_text = "\nDM: Sent" +            # Accordingly display whether the user was successfully notified via DM. +            if await _utils.notify_infraction(user, infr_type.replace("_", " ").title(), expiry, user_reason, icon): +                dm_result = ":incoming_envelope: " +                dm_log_text = "\nDM: Sent"          end_msg = ""          if infraction["actor"] == self.bot.user.id: @@ -264,14 +257,18 @@ class InfractionScheduler:              self,              ctx: Context,              infr_type: str, -            user: UserSnowflake, -            send_msg: bool = True +            user: MemberOrUser, +            *, +            send_msg: bool = True, +            notify: bool = True      ) -> None:          """          Prematurely end an infraction for a user and log the action in the mod log.          If `send_msg` is True, then a pardoning confirmation message will be sent to -        the context channel.  Otherwise, no such message will be sent. +        the context channel. Otherwise, no such message will be sent. + +        If `notify` is True, notify the user of the pardon via DM where applicable.          """          log.trace(f"Pardoning {infr_type} infraction for {user}.") @@ -292,7 +289,7 @@ class InfractionScheduler:              return          # Deactivate the infraction and cancel its scheduled expiration task. -        log_text = await self.deactivate_infraction(response[0], send_log=False) +        log_text = await self.deactivate_infraction(response[0], send_log=False, notify=notify)          log_text["Member"] = messages.format_user(user)          log_text["Actor"] = ctx.author.mention @@ -345,7 +342,9 @@ class InfractionScheduler:      async def deactivate_infraction(          self,          infraction: _utils.Infraction, -        send_log: bool = True +        *, +        send_log: bool = True, +        notify: bool = True      ) -> t.Dict[str, str]:          """          Deactivate an active infraction and return a dictionary of lines to send in a mod log. @@ -354,6 +353,8 @@ class InfractionScheduler:          expiration task cancelled. If `send_log` is True, a mod log is sent for the          deactivation of the infraction. +        If `notify` is True, notify the user of the pardon via DM where applicable. +          Infractions of unsupported types will raise a ValueError.          """          guild = self.bot.get_guild(constants.Guild.id) @@ -380,7 +381,7 @@ class InfractionScheduler:          try:              log.trace("Awaiting the pardon action coroutine.") -            returned_log = await self._pardon_action(infraction) +            returned_log = await self._pardon_action(infraction, notify)              if returned_log is not None:                  log_text = {**log_text, **returned_log}  # Merge the logs together @@ -468,10 +469,15 @@ class InfractionScheduler:          return log_text      @abstractmethod -    async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: +    async def _pardon_action( +        self, +        infraction: _utils.Infraction, +        notify: bool +    ) -> t.Optional[t.Dict[str, str]]:          """          Execute deactivation steps specific to the infraction's type and return a log dict. +        If `notify` is True, notify the user of the pardon via DM where applicable.          If an infraction type is unsupported, return None instead.          """          raise NotImplementedError diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index a4059a6e9..b20ef1d06 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -7,6 +7,7 @@ from discord.ext.commands import Context  from bot.api import ResponseCodeError  from bot.constants import Colours, Icons +from bot.converters import MemberOrUser  from bot.errors import InvalidInfractedUserError  log = logging.getLogger(__name__) @@ -24,8 +25,6 @@ INFRACTION_ICONS = {  RULES_URL = "https://pythondiscord.com/pages/rules"  # Type aliases -UserObject = t.Union[discord.Member, discord.User] -UserSnowflake = t.Union[UserObject, discord.Object]  Infraction = t.Dict[str, t.Union[str, int, bool]]  APPEAL_EMAIL = "[email protected]" @@ -45,7 +44,7 @@ INFRACTION_DESCRIPTION_TEMPLATE = (  ) -async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]: +async def post_user(ctx: Context, user: MemberOrUser) -> t.Optional[dict]:      """      Create a new user in the database. @@ -53,14 +52,11 @@ async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]:      """      log.trace(f"Attempting to add user {user.id} to the database.") -    if not isinstance(user, (discord.Member, discord.User)): -        log.debug("The user being added to the DB is not a Member or User object.") -      payload = { -        'discriminator': int(getattr(user, 'discriminator', 0)), +        'discriminator': int(user.discriminator),          'id': user.id,          'in_guild': False, -        'name': getattr(user, 'name', 'Name unknown'), +        'name': user.name,          'roles': []      } @@ -75,7 +71,7 @@ async def post_user(ctx: Context, user: UserSnowflake) -> t.Optional[dict]:  async def post_infraction(          ctx: Context, -        user: UserSnowflake, +        user: MemberOrUser,          infr_type: str,          reason: str,          expires_at: datetime = None, @@ -118,7 +114,7 @@ async def post_infraction(  async def get_active_infraction(          ctx: Context, -        user: UserSnowflake, +        user: MemberOrUser,          infr_type: str,          send_msg: bool = True  ) -> t.Optional[dict]: @@ -143,17 +139,22 @@ async def get_active_infraction(          # Checks to see if the moderator should be told there is an active infraction          if send_msg:              log.trace(f"{user} has active infractions of type {infr_type}.") -            await ctx.send( -                f":x: According to my records, this user already has a {infr_type} infraction. " -                f"See infraction **#{active_infractions[0]['id']}**." -            ) +            await send_active_infraction_message(ctx, active_infractions[0])          return active_infractions[0]      else:          log.trace(f"{user} does not have active infractions of type {infr_type}.") +async def send_active_infraction_message(ctx: Context, infraction: Infraction) -> None: +    """Send a message stating that the given infraction is active.""" +    await ctx.send( +        f":x: According to my records, this user already has a {infraction['type']} infraction. " +        f"See infraction **#{infraction['id']}**." +    ) + +  async def notify_infraction( -        user: UserObject, +        user: MemberOrUser,          infr_type: str,          expires_at: t.Optional[str] = None,          reason: t.Optional[str] = None, @@ -189,7 +190,7 @@ async def notify_infraction(  async def notify_pardon( -        user: UserObject, +        user: MemberOrUser,          title: str,          content: str,          icon_url: str = Icons.user_verified @@ -207,7 +208,7 @@ async def notify_pardon(      return await send_private_embed(user, embed) -async def send_private_embed(user: UserObject, embed: discord.Embed) -> bool: +async def send_private_embed(user: MemberOrUser, embed: discord.Embed) -> bool:      """      A helper method for sending an embed to a user's DMs. diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index f19323c7c..2f9083c29 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -10,11 +10,10 @@ from discord.ext.commands import Context, command  from bot import constants  from bot.bot import Bot  from bot.constants import Event -from bot.converters import Duration, Expiry, FetchedMember +from bot.converters import Duration, Expiry, MemberOrUser  from bot.decorators import respect_role_hierarchy  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler -from bot.exts.moderation.infraction._utils import UserSnowflake  from bot.utils.messages import format_user  log = logging.getLogger(__name__) @@ -54,7 +53,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent infractions      @command() -    async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: +    async def warn(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Warn a user for the given reason."""          if not isinstance(user, Member):              await ctx.send(":x: The user doesn't appear to be on the server.") @@ -67,7 +66,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user)      @command() -    async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: +    async def kick(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Kick a user for the given reason."""          if not isinstance(user, Member):              await ctx.send(":x: The user doesn't appear to be on the server.") @@ -79,7 +78,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def ban(          self,          ctx: Context, -        user: FetchedMember, +        user: MemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -95,7 +94,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def purgeban(          self,          ctx: Context, -        user: FetchedMember, +        user: MemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -111,7 +110,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def voiceban(          self,          ctx: Context, -        user: FetchedMember, +        user: MemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] @@ -129,7 +128,7 @@ class Infractions(InfractionScheduler, commands.Cog):      @command(aliases=["mute"])      async def tempmute(          self, ctx: Context, -        user: FetchedMember, +        user: MemberOrUser,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -163,7 +162,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def tempban(          self,          ctx: Context, -        user: FetchedMember, +        user: MemberOrUser,          duration: Expiry,          *,          reason: t.Optional[str] = None @@ -189,7 +188,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def tempvoiceban(              self,              ctx: Context, -            user: FetchedMember, +            user: MemberOrUser,              duration: Expiry,              *,              reason: t.Optional[str] @@ -215,7 +214,7 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent shadow infractions      @command(hidden=True) -    async def note(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: +    async def note(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Create a private note for a user with the given reason without notifying the user."""          infraction = await _utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False)          if infraction is None: @@ -224,7 +223,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user)      @command(hidden=True, aliases=['shadowban', 'sban']) -    async def shadow_ban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None: +    async def shadow_ban(self, ctx: Context, user: MemberOrUser, *, reason: t.Optional[str] = None) -> None:          """Permanently ban a user for the given reason without notifying the user."""          await self.apply_ban(ctx, user, reason, hidden=True) @@ -235,7 +234,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def shadow_tempban(          self,          ctx: Context, -        user: FetchedMember, +        user: MemberOrUser,          duration: Expiry,          *,          reason: t.Optional[str] = None @@ -261,17 +260,17 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Remove infractions (un- commands)      @command() -    async def unmute(self, ctx: Context, user: FetchedMember) -> None: +    async def unmute(self, ctx: Context, user: MemberOrUser) -> None:          """Prematurely end the active mute infraction for the user."""          await self.pardon_infraction(ctx, "mute", user)      @command() -    async def unban(self, ctx: Context, user: FetchedMember) -> None: +    async def unban(self, ctx: Context, user: MemberOrUser) -> None:          """Prematurely end the active ban infraction for the user."""          await self.pardon_infraction(ctx, "ban", user)      @command(aliases=("uvban",)) -    async def unvoiceban(self, ctx: Context, user: FetchedMember) -> None: +    async def unvoiceban(self, ctx: Context, user: MemberOrUser) -> None:          """Prematurely end the active voice ban infraction for the user."""          await self.pardon_infraction(ctx, "voice_ban", user) @@ -280,8 +279,19 @@ class Infractions(InfractionScheduler, commands.Cog):      async def apply_mute(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None:          """Apply a mute infraction with kwargs passed to `post_infraction`.""" -        if await _utils.get_active_infraction(ctx, user, "mute"): -            return +        if active := await _utils.get_active_infraction(ctx, user, "mute", send_msg=False): +            if active["actor"] != self.bot.user.id: +                await _utils.send_active_infraction_message(ctx, active) +                return + +            # Allow the current mute attempt to override an automatically triggered mute. +            log_text = await self.deactivate_infraction(active, notify=False) +            if "Failure" in log_text: +                await ctx.send( +                    f":x: can't override infraction **mute** for {user.mention}: " +                    f"failed to deactivate. {log_text['Failure']}" +                ) +                return          infraction = await _utils.post_infraction(ctx, user, "mute", reason, active=True, **kwargs)          if infraction is None: @@ -320,7 +330,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def apply_ban(          self,          ctx: Context, -        user: UserSnowflake, +        user: MemberOrUser,          reason: t.Optional[str],          purge_days: t.Optional[int] = 0,          **kwargs @@ -345,7 +355,7 @@ class Infractions(InfractionScheduler, commands.Cog):                  return              log.trace("Old tempban is being replaced by new permaban.") -            await self.pardon_infraction(ctx, "ban", user, is_temporary) +            await self.pardon_infraction(ctx, "ban", user, send_msg=is_temporary)          infraction = await _utils.post_infraction(ctx, user, "ban", reason, active=True, **kwargs)          if infraction is None: @@ -376,7 +386,7 @@ class Infractions(InfractionScheduler, commands.Cog):          await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False)      @respect_role_hierarchy(member_arg=2) -    async def apply_voice_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: +    async def apply_voice_ban(self, ctx: Context, user: MemberOrUser, reason: t.Optional[str], **kwargs) -> None:          """Apply a voice ban infraction with kwargs passed to `post_infraction`."""          if await _utils.get_active_infraction(ctx, user, "voice_ban"):              return @@ -403,8 +413,15 @@ class Infractions(InfractionScheduler, commands.Cog):      # endregion      # region: Base pardon functions -    async def pardon_mute(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: -        """Remove a user's muted role, DM them a notification, and return a log dict.""" +    async def pardon_mute( +        self, +        user_id: int, +        guild: discord.Guild, +        reason: t.Optional[str], +        *, +        notify: bool = True +    ) -> t.Dict[str, str]: +        """Remove a user's muted role, optionally DM them a notification, and return a log dict."""          user = guild.get_member(user_id)          log_text = {} @@ -413,16 +430,17 @@ class Infractions(InfractionScheduler, commands.Cog):              self.mod_log.ignore(Event.member_update, user.id)              await user.remove_roles(self._muted_role, reason=reason) -            # DM the user about the expiration. -            notified = await _utils.notify_pardon( -                user=user, -                title="You have been unmuted", -                content="You may now send messages in the server.", -                icon_url=_utils.INFRACTION_ICONS["mute"][1] -            ) +            if notify: +                # DM the user about the expiration. +                notified = await _utils.notify_pardon( +                    user=user, +                    title="You have been unmuted", +                    content="You may now send messages in the server.", +                    icon_url=_utils.INFRACTION_ICONS["mute"][1] +                ) +                log_text["DM"] = "Sent" if notified else "**Failed**"              log_text["Member"] = format_user(user) -            log_text["DM"] = "Sent" if notified else "**Failed**"          else:              log.info(f"Failed to unmute user {user_id}: user not found")              log_text["Failure"] = "User was not found in the guild." @@ -444,31 +462,39 @@ class Infractions(InfractionScheduler, commands.Cog):          return log_text -    async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: -        """Add Voice Verified role back to user, DM them a notification, and return a log dict.""" +    async def pardon_voice_ban( +        self, +        user_id: int, +        guild: discord.Guild, +        *, +        notify: bool = True +    ) -> t.Dict[str, str]: +        """Optionally DM the user a pardon notification and return a log dict."""          user = guild.get_member(user_id)          log_text = {}          if user: -            # DM user about infraction expiration -            notified = await _utils.notify_pardon( -                user=user, -                title="Voice ban ended", -                content="You have been unbanned and can verify yourself again in the server.", -                icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] -            ) +            if notify: +                # DM user about infraction expiration +                notified = await _utils.notify_pardon( +                    user=user, +                    title="Voice ban ended", +                    content="You have been unbanned and can verify yourself again in the server.", +                    icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] +                ) +                log_text["DM"] = "Sent" if notified else "**Failed**"              log_text["Member"] = format_user(user) -            log_text["DM"] = "Sent" if notified else "**Failed**"          else:              log_text["Info"] = "User was not found in the guild."          return log_text -    async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: +    async def _pardon_action(self, infraction: _utils.Infraction, notify: bool) -> t.Optional[t.Dict[str, str]]:          """          Execute deactivation steps specific to the infraction's type and return a log dict. +        If `notify` is True, notify the user of the pardon via DM where applicable.          If an infraction type is unsupported, return None instead.          """          guild = self.bot.get_guild(constants.Guild.id) @@ -476,11 +502,11 @@ class Infractions(InfractionScheduler, commands.Cog):          reason = f"Infraction #{infraction['id']} expired or was pardoned."          if infraction["type"] == "mute": -            return await self.pardon_mute(user_id, guild, reason) +            return await self.pardon_mute(user_id, guild, reason, notify=notify)          elif infraction["type"] == "ban":              return await self.pardon_ban(user_id, guild, reason)          elif infraction["type"] == "voice_ban": -            return await self.pardon_voice_ban(user_id, guild, reason) +            return await self.pardon_voice_ban(user_id, guild, notify=notify)      # endregion diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index 3094159cd..641ad0410 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -12,7 +12,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry, Infraction, Snowflake, UserMention, allowed_strings, proxy_user +from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UserMentionOrID, allowed_strings  from bot.exts.moderation.infraction.infractions import Infractions  from bot.exts.moderation.modlog import ModLog  from bot.pagination import LinePaginator @@ -201,29 +201,34 @@ class ModManagement(commands.Cog):      # region: Search infractions      @infraction_group.group(name="search", aliases=('s',), invoke_without_command=True) -    async def infraction_search_group(self, ctx: Context, query: t.Union[UserMention, Snowflake, str]) -> None: +    async def infraction_search_group(self, ctx: Context, query: t.Union[UserMentionOrID, Snowflake, str]) -> None:          """Searches for infractions in the database."""          if isinstance(query, int):              await self.search_user(ctx, discord.Object(query)) -        else: +        elif isinstance(query, str):              await self.search_reason(ctx, query) +        else: +            await self.search_user(ctx, query)      @infraction_search_group.command(name="user", aliases=("member", "id")) -    async def search_user(self, ctx: Context, user: t.Union[discord.User, proxy_user]) -> None: +    async def search_user(self, ctx: Context, user: t.Union[MemberOrUser, discord.Object]) -> None:          """Search for infractions by member."""          infraction_list = await self.bot.api_client.get(              'bot/infractions/expanded',              params={'user__id': str(user.id)}          ) -        user = self.bot.get_user(user.id) -        if not user and infraction_list: -            # Use the user data retrieved from the DB for the username. -            user = infraction_list[0]["user"] -            user = escape_markdown(user["name"]) + f"#{user['discriminator']:04}" +        if isinstance(user, (discord.Member, discord.User)): +            user_str = escape_markdown(str(user)) +        else: +            if infraction_list: +                user = infraction_list[0]["user"] +                user_str = escape_markdown(user["name"]) + f"#{user['discriminator']:04}" +            else: +                user_str = str(user.id)          embed = discord.Embed( -            title=f"Infractions for {user} ({len(infraction_list)} total)", +            title=f"Infractions for {user_str} ({len(infraction_list)} total)",              colour=discord.Colour.orange()          )          await self.send_infraction_list(ctx, embed, infraction_list) diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 07e79b9fe..05a2bbe10 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -192,8 +192,8 @@ class Superstarify(InfractionScheduler, Cog):          """Remove the superstarify infraction and allow the user to change their nickname."""          await self.pardon_infraction(ctx, "superstar", member) -    async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]: -        """Pardon a superstar infraction and return a log dict.""" +    async def _pardon_action(self, infraction: _utils.Infraction, notify: bool) -> t.Optional[t.Dict[str, str]]: +        """Pardon a superstar infraction, optionally notify the user via DM, and return a log dict."""          if infraction["type"] != "superstar":              return @@ -208,18 +208,19 @@ class Superstarify(InfractionScheduler, Cog):              )              return {} +        log_text = {"Member": format_user(user)} +          # DM the user about the expiration. -        notified = await _utils.notify_pardon( -            user=user, -            title="You are no longer superstarified", -            content="You may now change your nickname on the server.", -            icon_url=_utils.INFRACTION_ICONS["superstar"][1] -        ) +        if notify: +            notified = await _utils.notify_pardon( +                user=user, +                title="You are no longer superstarified", +                content="You may now change your nickname on the server.", +                icon_url=_utils.INFRACTION_ICONS["superstar"][1] +            ) +            log_text["DM"] = "Sent" if notified else "**Failed**" -        return { -            "Member": format_user(user), -            "DM": "Sent" if notified else "**Failed**" -        } +        return log_text      @staticmethod      def get_nick(infraction_id: int, member_id: int) -> str: diff --git a/bot/exts/moderation/metabase.py b/bot/exts/moderation/metabase.py index e9faf7240..3b454ab18 100644 --- a/bot/exts/moderation/metabase.py +++ b/bot/exts/moderation/metabase.py @@ -42,6 +42,25 @@ class Metabase(Cog):          self.init_task = self.bot.loop.create_task(self.init_cog()) +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Handle ClientResponseError errors locally to invalidate token if needed.""" +        if not isinstance(error.original, ClientResponseError): +            return + +        if error.original.status == 403: +            # User doesn't have access to the given question +            log.warning(f"Failed to auth with Metabase for {error.original.url}.") +            await ctx.send(f":x: {ctx.author.mention} Failed to auth with Metabase for that question.") +        elif error.original.status == 404: +            await ctx.send(f":x: {ctx.author.mention} That question could not be found.") +        else: +            # User credentials are invalid, or the refresh failed. +            # Delete the expiry time, to force a refresh on next startup. +            await self.session_info.delete("session_expiry") +            log.exception("Session token is invalid or refresh failed.") +            await ctx.send(f":x: {ctx.author.mention} Session token is invalid or refresh failed.") +        error.handled = True +      async def init_cog(self) -> None:          """Initialise the metabase session."""          expiry_time = await self.session_info.get("session_expiry") @@ -65,7 +84,7 @@ class Metabase(Cog):              "username": MetabaseConfig.username,              "password": MetabaseConfig.password          } -        async with self.bot.http_session.post(f"{MetabaseConfig.url}/session", json=data) as resp: +        async with self.bot.http_session.post(f"{MetabaseConfig.base_url}/api/session", json=data) as resp:              json_data = await resp.json()              self.session_token = json_data.get("id") @@ -86,7 +105,7 @@ class Metabase(Cog):          """A group of commands for interacting with metabase."""          await ctx.send_help(ctx.command) -    @metabase_group.command(name="extract") +    @metabase_group.command(name="extract", aliases=("export",))      async def metabase_extract(          self,          ctx: Context, @@ -106,48 +125,50 @@ class Metabase(Cog):          Valid extensions are: csv and json.          """ -        async with ctx.typing(): - -            # Make sure we have a session token before running anything -            await self.init_task - -            url = f"{MetabaseConfig.url}/card/{question_id}/query/{extension}" -            try: -                async with self.bot.http_session.post(url, headers=self.headers, raise_for_status=True) as resp: -                    if extension == "csv": -                        out = await resp.text(encoding="utf-8") -                        # Save the output for use with int e -                        self.exports[question_id] = list(csv.DictReader(StringIO(out))) - -                    elif extension == "json": -                        out = await resp.json(encoding="utf-8") -                        # Save the output for use with int e -                        self.exports[question_id] = out - -                        # Format it nicely for human eyes -                        out = json.dumps(out, indent=4, sort_keys=True) -            except ClientResponseError as e: -                if e.status == 403: -                    # User doesn't have access to the given question -                    log.warning(f"Failed to auth with Metabase for question {question_id}.") -                    await ctx.send(f":x: {ctx.author.mention} Failed to auth with Metabase for that question.") -                else: -                    # User credentials are invalid, or the refresh failed. -                    # Delete the expiry time, to force a refresh on next startup. -                    await self.session_info.delete("session_expiry") -                    log.exception("Session token is invalid or refresh failed.") -                    await ctx.send(f":x: {ctx.author.mention} Session token is invalid or refresh failed.") -                return - -            paste_link = await send_to_paste_service(out, extension=extension) -            if paste_link: -                message = f":+1: {ctx.author.mention} Here's your link: {paste_link}" -            else: -                message = f":x: {ctx.author.mention} Link service is unavailible." -            await ctx.send( -                f"{message}\nYou can also access this data within internal eval by doing: " -                f"`bot.get_cog('Metabase').exports[{question_id}]`" -            ) +        await ctx.trigger_typing() + +        # Make sure we have a session token before running anything +        await self.init_task + +        url = f"{MetabaseConfig.base_url}/api/card/{question_id}/query/{extension}" + +        async with self.bot.http_session.post(url, headers=self.headers, raise_for_status=True) as resp: +            if extension == "csv": +                out = await resp.text(encoding="utf-8") +                # Save the output for use with int e +                self.exports[question_id] = list(csv.DictReader(StringIO(out))) + +            elif extension == "json": +                out = await resp.json(encoding="utf-8") +                # Save the output for use with int e +                self.exports[question_id] = out + +                # Format it nicely for human eyes +                out = json.dumps(out, indent=4, sort_keys=True) + +        paste_link = await send_to_paste_service(out, extension=extension) +        if paste_link: +            message = f":+1: {ctx.author.mention} Here's your link: {paste_link}" +        else: +            message = f":x: {ctx.author.mention} Link service is unavailible." +        await ctx.send( +            f"{message}\nYou can also access this data within internal eval by doing: " +            f"`bot.get_cog('Metabase').exports[{question_id}]`" +        ) + +    @metabase_group.command(name="publish", aliases=("share",)) +    async def metabase_publish(self, ctx: Context, question_id: int) -> None: +        """Publically shares the given question and posts the link.""" +        await ctx.trigger_typing() +        # Make sure we have a session token before running anything +        await self.init_task + +        url = f"{MetabaseConfig.base_url}/api/card/{question_id}/public_link" + +        async with self.bot.http_session.post(url, headers=self.headers, raise_for_status=True) as resp: +            response_json = await resp.json(encoding="utf-8") +            sharing_url = f"{MetabaseConfig.base_url}/public/question/{response_json['uuid']}" +            await ctx.send(f":+1: {ctx.author.mention} Here's your sharing link: {sharing_url}")      # This cannot be static (must have a __func__ attribute).      async def cog_check(self, ctx: Context) -> bool: diff --git a/bot/exts/moderation/watchchannels/bigbrother.py b/bot/exts/moderation/watchchannels/bigbrother.py index c6ee844ef..3aa253fea 100644 --- a/bot/exts/moderation/watchchannels/bigbrother.py +++ b/bot/exts/moderation/watchchannels/bigbrother.py @@ -6,7 +6,7 @@ from discord.ext.commands import Cog, Context, group, has_any_role  from bot.bot import Bot  from bot.constants import Channels, MODERATION_ROLES, Webhooks -from bot.converters import FetchedMember +from bot.converters import MemberOrUser  from bot.exts.moderation.infraction._utils import post_infraction  from bot.exts.moderation.watchchannels._watchchannel import WatchChannel @@ -60,7 +60,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):      @bigbrother_group.command(name='watch', aliases=('w',), root_aliases=('watch',))      @has_any_role(*MODERATION_ROLES) -    async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: +    async def watch_command(self, ctx: Context, user: MemberOrUser, *, reason: str) -> None:          """          Relay messages sent by the given `user` to the `#big-brother` channel. @@ -71,11 +71,11 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):      @bigbrother_group.command(name='unwatch', aliases=('uw',), root_aliases=('unwatch',))      @has_any_role(*MODERATION_ROLES) -    async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: +    async def unwatch_command(self, ctx: Context, user: MemberOrUser, *, reason: str) -> None:          """Stop relaying messages by the given `user`."""          await self.apply_unwatch(ctx, user, reason) -    async def apply_watch(self, ctx: Context, user: FetchedMember, reason: str) -> None: +    async def apply_watch(self, ctx: Context, user: MemberOrUser, reason: str) -> None:          """          Add `user` to watched users and apply a watch infraction with `reason`. @@ -94,7 +94,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):              await ctx.send(f":x: {user} is already being watched.")              return -        # FetchedUser instances don't have a roles attribute +        # discord.User instances don't have a roles attribute          if hasattr(user, "roles") and any(role.id in MODERATION_ROLES for role in user.roles):              await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I must be kind to my masters.")              return @@ -125,7 +125,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):          await ctx.send(msg) -    async def apply_unwatch(self, ctx: Context, user: FetchedMember, reason: str, send_message: bool = True) -> None: +    async def apply_unwatch(self, ctx: Context, user: MemberOrUser, reason: str, send_message: bool = True) -> None:          """          Remove `user` from watched users and mark their infraction as inactive with `reason`. diff --git a/bot/exts/recruitment/talentpool/_cog.py b/bot/exts/recruitment/talentpool/_cog.py index 80bd48534..5c1a1cd3f 100644 --- a/bot/exts/recruitment/talentpool/_cog.py +++ b/bot/exts/recruitment/talentpool/_cog.py @@ -6,13 +6,13 @@ from typing import Union  import discord  from async_rediscache import RedisCache -from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent, User +from discord import Color, Embed, Member, PartialMessage, RawReactionActionEvent  from discord.ext.commands import Cog, Context, group, has_any_role  from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles, STAFF_ROLES, Webhooks -from bot.converters import FetchedMember +from bot.converters import MemberOrUser  from bot.exts.moderation.watchchannels._watchchannel import WatchChannel  from bot.exts.recruitment.talentpool._review import Reviewer  from bot.pagination import LinePaginator @@ -178,7 +178,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):      @nomination_group.command(name='forcewatch', aliases=('fw', 'forceadd', 'fa'), root_aliases=("forcenominate",))      @has_any_role(*MODERATION_ROLES) -    async def force_watch_command(self, ctx: Context, user: FetchedMember, *, reason: str = '') -> None: +    async def force_watch_command(self, ctx: Context, user: MemberOrUser, *, reason: str = '') -> None:          """          Adds the given `user` to the talent pool, from any channel. @@ -188,7 +188,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):      @nomination_group.command(name='watch', aliases=('w', 'add', 'a'), root_aliases=("nominate",))      @has_any_role(*STAFF_ROLES) -    async def watch_command(self, ctx: Context, user: FetchedMember, *, reason: str = '') -> None: +    async def watch_command(self, ctx: Context, user: MemberOrUser, *, reason: str = '') -> None:          """          Adds the given `user` to the talent pool. @@ -207,7 +207,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          await self._watch_user(ctx, user, reason) -    async def _watch_user(self, ctx: Context, user: FetchedMember, reason: str) -> None: +    async def _watch_user(self, ctx: Context, user: MemberOrUser, reason: str) -> None:          """Adds the given user to the talent pool."""          if user.bot:              await ctx.send(f":x: I'm sorry {ctx.author}, I'm afraid I can't do that. I only watch humans.") @@ -271,7 +271,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):      @nomination_group.command(name='history', aliases=('info', 'search'))      @has_any_role(*MODERATION_ROLES) -    async def history_command(self, ctx: Context, user: FetchedMember) -> None: +    async def history_command(self, ctx: Context, user: MemberOrUser) -> None:          """Shows the specified user's nomination history."""          result = await self.bot.api_client.get(              self.api_endpoint, @@ -300,7 +300,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):      @nomination_group.command(name='unwatch', aliases=('end', ), root_aliases=("unnominate",))      @has_any_role(*MODERATION_ROLES) -    async def unwatch_command(self, ctx: Context, user: FetchedMember, *, reason: str) -> None: +    async def unwatch_command(self, ctx: Context, user: MemberOrUser, *, reason: str) -> None:          """          Ends the active nomination of the specified user with the given reason. @@ -323,7 +323,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):      @nomination_edit_group.command(name='reason')      @has_any_role(*MODERATION_ROLES) -    async def edit_reason_command(self, ctx: Context, nomination_id: int, actor: FetchedMember, *, reason: str) -> None: +    async def edit_reason_command(self, ctx: Context, nomination_id: int, actor: MemberOrUser, *, reason: str) -> None:          """Edits the reason of a specific nominator in a specific active nomination."""          if len(reason) > REASON_MAX_CHARS:              await ctx.send(f":x: Maxiumum allowed characters for the reason is {REASON_MAX_CHARS}.") @@ -417,7 +417,7 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):          await ctx.message.add_reaction(Emojis.check_mark)      @Cog.listener() -    async def on_member_ban(self, guild: Guild, user: Union[User, Member]) -> None: +    async def on_member_ban(self, guild: Guild, user: Union[MemberOrUser]) -> None:          """Remove `user` from the talent pool after they are banned."""          await self.unwatch(user.id, "User was banned.") diff --git a/bot/exts/recruitment/talentpool/_review.py b/bot/exts/recruitment/talentpool/_review.py index 3a1e66970..4d496a1f7 100644 --- a/bot/exts/recruitment/talentpool/_review.py +++ b/bot/exts/recruitment/talentpool/_review.py @@ -15,7 +15,7 @@ from discord.ext.commands import Context  from bot.api import ResponseCodeError  from bot.bot import Bot -from bot.constants import Channels, Colours, Emojis, Guild, Roles +from bot.constants import Channels, Colours, Emojis, Guild  from bot.utils.messages import count_unique_users_reaction, pin_no_system_message  from bot.utils.scheduling import Scheduler  from bot.utils.time import get_time_delta, time_since @@ -33,10 +33,12 @@ MAX_MESSAGE_SIZE = 2000  # Maximum amount of characters allowed in an embed  MAX_EMBED_SIZE = 4000 -# Regex finding the user ID of a user mention -MENTION_RE = re.compile(r"<@!?(\d+?)>") -# Regex matching role pings -ROLE_MENTION_RE = re.compile(r"<@&\d+>") +# Regex for finding the first message of a nomination, and extracting the nominee. +# Historic nominations will have 2 role mentions at the start, new ones won't, optionally match for this. +NOMINATION_MESSAGE_REGEX = re.compile( +    r"(?:<@&\d+> <@&\d+>\n)*?<@!?(\d+?)> \(.+#\d{4}\) for Helper!\n\n\*\*Nominated by:\*\*", +    re.MULTILINE +)  class Reviewer: @@ -118,7 +120,7 @@ class Reviewer:                  f"I tried to review the user with ID `{user_id}`, but they don't appear to be on the server :pensive:"              ), None -        opening = f"<@&{Roles.mod_team}> <@&{Roles.admins}>\n{member.mention} ({member}) for Helper!" +        opening = f"{member.mention} ({member}) for Helper!"          current_nominations = "\n\n".join(              f"**<@{entry['actor']}>:** {entry['reason'] or '*no reason given*'}" @@ -142,14 +144,14 @@ class Reviewer:          """Archive this vote to #nomination-archive."""          message = await message.fetch() -        # We consider the first message in the nomination to contain the two role pings +        # We consider the first message in the nomination to contain the user ping, username#discrim, and fixed text          messages = [message] -        if not len(ROLE_MENTION_RE.findall(message.content)) >= 2: +        if not NOMINATION_MESSAGE_REGEX.search(message.content):              with contextlib.suppress(NoMoreItems):                  async for new_message in message.channel.history(before=message.created_at):                      messages.append(new_message) -                    if len(ROLE_MENTION_RE.findall(new_message.content)) >= 2: +                    if NOMINATION_MESSAGE_REGEX.search(new_message.content):                          break          log.debug(f"Found {len(messages)} messages: {', '.join(str(m.id) for m in messages)}") @@ -161,7 +163,7 @@ class Reviewer:          content = "".join(parts)          # We assume that the first user mentioned is the user that we are voting on -        user_id = int(MENTION_RE.search(content).group(1)) +        user_id = int(NOMINATION_MESSAGE_REGEX.search(content).group(1))          # Get reaction counts          reviewed = await count_unique_users_reaction( diff --git a/bot/exts/utils/jams.py b/bot/exts/utils/jams.py deleted file mode 100644 index 87ae847f6..000000000 --- a/bot/exts/utils/jams.py +++ /dev/null @@ -1,176 +0,0 @@ -import csv -import logging -import typing as t -from collections import defaultdict - -import discord -from discord.ext import commands - -from bot.bot import Bot -from bot.constants import Categories, Channels, Emojis, Roles - -log = logging.getLogger(__name__) - -MAX_CHANNELS = 50 -CATEGORY_NAME = "Code Jam" -TEAM_LEADERS_COLOUR = 0x11806a - - -class CodeJams(commands.Cog): -    """Manages the code-jam related parts of our server.""" - -    def __init__(self, bot: Bot): -        self.bot = bot - -    @commands.group() -    @commands.has_any_role(Roles.admins) -    async def codejam(self, ctx: commands.Context) -> None: -        """A Group of commands for managing Code Jams.""" -        if ctx.invoked_subcommand is None: -            await ctx.send_help(ctx.command) - -    @codejam.command() -    async def create(self, ctx: commands.Context, csv_file: t.Optional[str]) -> None: -        """ -        Create code-jam teams from a CSV file or a link to one, specifying the team names, leaders and members. - -        The CSV file must have 3 columns: 'Team Name', 'Team Member Discord ID', and 'Team Leader'. - -        This will create the text channels for the teams, and give the team leaders their roles. -        """ -        async with ctx.typing(): -            if csv_file: -                async with self.bot.http_session.get(csv_file) as response: -                    if response.status != 200: -                        await ctx.send(f"Got a bad response from the URL: {response.status}") -                        return - -                    csv_file = await response.text() - -            elif ctx.message.attachments: -                csv_file = (await ctx.message.attachments[0].read()).decode("utf8") -            else: -                raise commands.BadArgument("You must include either a CSV file or a link to one.") - -            teams = defaultdict(list) -            reader = csv.DictReader(csv_file.splitlines()) - -            for row in reader: -                member = ctx.guild.get_member(int(row["Team Member Discord ID"])) - -                if member is None: -                    log.trace(f"Got an invalid member ID: {row['Team Member Discord ID']}") -                    continue - -                teams[row["Team Name"]].append((member, row["Team Leader"].upper() == "Y")) - -            team_leaders = await ctx.guild.create_role(name="Code Jam Team Leaders", colour=TEAM_LEADERS_COLOUR) - -            for team_name, members in teams.items(): -                await self.create_team_channel(ctx.guild, team_name, members, team_leaders) - -            await self.create_team_leader_channel(ctx.guild, team_leaders) -            await ctx.send(f"{Emojis.check_mark} Created Code Jam with {len(teams)} teams.") - -    async def get_category(self, guild: discord.Guild) -> discord.CategoryChannel: -        """ -        Return a code jam category. - -        If all categories are full or none exist, create a new category. -        """ -        for category in guild.categories: -            if category.name == CATEGORY_NAME and len(category.channels) < MAX_CHANNELS: -                return category - -        return await self.create_category(guild) - -    async def create_category(self, guild: discord.Guild) -> discord.CategoryChannel: -        """Create a new code jam category and return it.""" -        log.info("Creating a new code jam category.") - -        category_overwrites = { -            guild.default_role: discord.PermissionOverwrite(read_messages=False), -            guild.me: discord.PermissionOverwrite(read_messages=True) -        } - -        category = await guild.create_category_channel( -            CATEGORY_NAME, -            overwrites=category_overwrites, -            reason="It's code jam time!" -        ) - -        await self.send_status_update( -            guild, f"Created a new category with the ID {category.id} for this Code Jam's team channels." -        ) - -        return category - -    @staticmethod -    def get_overwrites( -        members: list[tuple[discord.Member, bool]], -        guild: discord.Guild, -    ) -> dict[t.Union[discord.Member, discord.Role], discord.PermissionOverwrite]: -        """Get code jam team channels permission overwrites.""" -        team_channel_overwrites = { -            guild.default_role: discord.PermissionOverwrite(read_messages=False), -            guild.get_role(Roles.code_jam_event_team): discord.PermissionOverwrite(read_messages=True) -        } - -        for member, _ in members: -            team_channel_overwrites[member] = discord.PermissionOverwrite( -                read_messages=True -            ) - -        return team_channel_overwrites - -    async def create_team_channel( -        self, -        guild: discord.Guild, -        team_name: str, -        members: list[tuple[discord.Member, bool]], -        team_leaders: discord.Role -    ) -> None: -        """Create the team's text channel.""" -        await self.add_team_leader_roles(members, team_leaders) - -        # Get permission overwrites and category -        team_channel_overwrites = self.get_overwrites(members, guild) -        code_jam_category = await self.get_category(guild) - -        # Create a text channel for the team -        await code_jam_category.create_text_channel( -            team_name, -            overwrites=team_channel_overwrites, -        ) - -    async def create_team_leader_channel(self, guild: discord.Guild, team_leaders: discord.Role) -> None: -        """Create the Team Leader Chat channel for the Code Jam team leaders.""" -        category: discord.CategoryChannel = guild.get_channel(Categories.summer_code_jam) - -        team_leaders_chat = await category.create_text_channel( -            name="team-leaders-chat", -            overwrites={ -                guild.default_role: discord.PermissionOverwrite(read_messages=False), -                team_leaders: discord.PermissionOverwrite(read_messages=True) -            } -        ) - -        await self.send_status_update(guild, f"Created {team_leaders_chat.mention} in the {category} category.") - -    async def send_status_update(self, guild: discord.Guild, message: str) -> None: -        """Inform the events lead with a status update when the command is ran.""" -        channel: discord.TextChannel = guild.get_channel(Channels.code_jam_planning) - -        await channel.send(f"<@&{Roles.events_lead}>\n\n{message}") - -    @staticmethod -    async def add_team_leader_roles(members: list[tuple[discord.Member, bool]], team_leaders: discord.Role) -> None: -        """Assign team leader role, the jammer role and their team role.""" -        for member, is_leader in members: -            if is_leader: -                await member.add_roles(team_leaders) - - -def setup(bot: Bot) -> None: -    """Load the CodeJams cog.""" -    bot.add_cog(CodeJams(bot)) diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 7b8c5c4b3..144f7b537 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -12,13 +12,13 @@ from discord.ext.commands import Cog, Context, Greedy, group  from bot.bot import Bot  from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Roles, STAFF_ROLES -from bot.converters import Duration +from bot.converters import Duration, UserMentionOrID  from bot.pagination import LinePaginator  from bot.utils.checks import has_any_role_check, has_no_roles_check  from bot.utils.lock import lock_arg  from bot.utils.messages import send_denial  from bot.utils.scheduling import Scheduler -from bot.utils.time import TimestampFormats, discord_timestamp, time_since +from bot.utils.time import TimestampFormats, discord_timestamp  log = logging.getLogger(__name__) @@ -27,6 +27,7 @@ WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5  Mentionable = t.Union[discord.Member, discord.Role] +ReminderMention = t.Union[UserMentionOrID, discord.Role]  class Reminders(Cog): @@ -172,46 +173,53 @@ class Reminders(Cog):          if not is_valid:              # No need to cancel the task too; it'll simply be done once this coroutine returns.              return -          embed = discord.Embed() -        embed.colour = discord.Colour.blurple() -        embed.set_author( -            icon_url=Icons.remind_blurple, -            name="It has arrived!" -        ) - -        # Let's not use a codeblock to keep emojis and mentions working. Embeds are safe anyway. -        embed.description = f"Here's your reminder: {reminder['content']}." - -        if reminder.get("jump_url"):  # keep backward compatibility -            embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" -          if expected_time:              embed.colour = discord.Colour.red()              embed.set_author(                  icon_url=Icons.remind_red, -                name=f"Sorry it should have arrived {time_since(expected_time)} !" +                name="Sorry, your reminder should have arrived earlier!" +            ) +        else: +            embed.colour = discord.Colour.blurple() +            embed.set_author( +                icon_url=Icons.remind_blurple, +                name="It has arrived!"              ) +        # Let's not use a codeblock to keep emojis and mentions working. Embeds are safe anyway. +        embed.description = f"Here's your reminder: {reminder['content']}" + +        # Here the jump URL is in the format of base_url/guild_id/channel_id/message_id          additional_mentions = ' '.join(              mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"])          ) -        await channel.send(content=f"{user.mention} {additional_mentions}", embed=embed) +        jump_url = reminder.get("jump_url") +        embed.description += f"\n[Jump back to when you created the reminder]({jump_url})" +        partial_message = channel.get_partial_message(int(jump_url.split("/")[-1])) +        try: +            await partial_message.reply(content=f"{additional_mentions}", embed=embed) +        except discord.HTTPException as e: +            log.info( +                f"There was an error when trying to reply to a reminder invocation message, {e}, " +                "fall back to using jump_url" +            ) +            await channel.send(content=f"{user.mention} {additional_mentions}", embed=embed)          log.debug(f"Deleting reminder #{reminder['id']} (the user has been reminded).")          await self.bot.api_client.delete(f"bot/reminders/{reminder['id']}")      @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True)      async def remind_group( -        self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str +        self, ctx: Context, mentions: Greedy[ReminderMention], expiration: Duration, *, content: str      ) -> None:          """Commands for managing your reminders."""          await self.new_reminder(ctx, mentions=mentions, expiration=expiration, content=content)      @remind_group.command(name="new", aliases=("add", "create"))      async def new_reminder( -        self, ctx: Context, mentions: Greedy[Mentionable], expiration: Duration, *, content: str +        self, ctx: Context, mentions: Greedy[ReminderMention], expiration: Duration, *, content: str      ) -> None:          """          Set yourself a simple reminder. @@ -263,7 +271,7 @@ class Reminders(Cog):              }          ) -        mention_string = f"Your reminder will arrive {discord_timestamp(expiration, TimestampFormats.RELATIVE)}" +        mention_string = f"Your reminder will arrive on {discord_timestamp(expiration, TimestampFormats.DAY_TIME)}"          if mentions:              mention_string += f" and will mention {len(mentions)} other(s)" @@ -356,7 +364,7 @@ class Reminders(Cog):          await self.edit_reminder(ctx, id_, {"content": content})      @edit_reminder_group.command(name="mentions", aliases=("pings",)) -    async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[Mentionable]) -> None: +    async def edit_reminder_mentions(self, ctx: Context, id_: int, mentions: Greedy[ReminderMention]) -> None:          """Edit one of your reminder's mentions."""          # Remove duplicate mentions          mentions = set(mentions) diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index 98e43c32b..28c7ec27b 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -14,7 +14,6 @@ from bot.converters import Snowflake  from bot.decorators import in_whitelist  from bot.pagination import LinePaginator  from bot.utils import messages -from bot.utils.checks import has_no_roles_check  from bot.utils.time import time_since  log = logging.getLogger(__name__) @@ -160,9 +159,6 @@ class Utils(Cog):      @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES)      async def snowflake(self, ctx: Context, *snowflakes: Snowflake) -> None:          """Get Discord snowflake creation time.""" -        if len(snowflakes) > 1 and await has_no_roles_check(ctx, *STAFF_ROLES): -            raise BadArgument("Cannot process more than one snowflake in one invocation.") -          if not snowflakes:              raise BadArgument("At least one snowflake must be provided.") diff --git a/bot/pagination.py b/bot/pagination.py index 90d7c84ee..26caa7db0 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -75,7 +75,7 @@ class LinePaginator(Paginator):              raise ValueError(f"scale_to_size must be >= max_size. ({scale_to_size} < {max_size})")          if scale_to_size > 4000: -            raise ValueError(f"scale_to_size must be <= 2,000 characters. ({scale_to_size} > 4000)") +            raise ValueError(f"scale_to_size must be <= 4,000 characters. ({scale_to_size} > 4000)")          self.scale_to_size = scale_to_size - len(suffix)          self.max_lines = max_lines diff --git a/bot/resources/tags/blocking.md b/bot/resources/tags/blocking.md index 31d91294c..5554d7eba 100644 --- a/bot/resources/tags/blocking.md +++ b/bot/resources/tags/blocking.md @@ -1,9 +1,7 @@  **Why do we need asynchronous programming?** -  Imagine that you're coding a Discord bot and every time somebody uses a command, you need to get some information from a database. But there's a catch: the database servers are acting up today and take a whole 10 seconds to respond. If you do **not** use asynchronous methods, your whole bot will stop running until it gets a response from the database. How do you fix this? Asynchronous programming.  **What is asynchronous programming?** -  An asynchronous program utilises the `async` and `await` keywords. An asynchronous program pauses what it's doing and does something else whilst it waits for some third-party service to complete whatever it's supposed to do. Any code within an `async` context manager or function marked with the `await` keyword indicates to Python, that whilst this operation is being completed, it can do something else. For example:  ```py @@ -14,13 +12,10 @@ import discord  async def ping(ctx):      await ctx.send("Pong!")  ``` -  **What does the term "blocking" mean?** -  A blocking operation is wherever you do something without `await`ing it. This tells Python that this step must be completed before it can do anything else. Common examples of blocking operations, as simple as they may seem, include: outputting text, adding two numbers and appending an item onto a list. Most common Python libraries have an asynchronous version available to use in asynchronous contexts.  **`async` libraries** -  The standard async library - `asyncio`  Asynchronous web requests - `aiohttp`  Talking to PostgreSQL asynchronously - `asyncpg` diff --git a/bot/resources/tags/modmail.md b/bot/resources/tags/modmail.md index 412468174..8ac19c8a7 100644 --- a/bot/resources/tags/modmail.md +++ b/bot/resources/tags/modmail.md @@ -6,4 +6,4 @@ It supports attachments, codeblocks, and reactions. As communication happens ove  **To use it, simply send a direct message to the bot.** -Should there be an urgent and immediate need for a moderator or admin to look at a channel, feel free to ping the <@&831776746206265384> or <@&267628507062992896> role instead. +Should there be an urgent and immediate need for a moderator to look at a channel, feel free to ping the <@&831776746206265384> role instead. diff --git a/bot/utils/cache.py b/bot/utils/caching.py index 68ce15607..68ce15607 100644 --- a/bot/utils/cache.py +++ b/bot/utils/caching.py diff --git a/bot/utils/message_cache.py b/bot/utils/message_cache.py new file mode 100644 index 000000000..f68d280c9 --- /dev/null +++ b/bot/utils/message_cache.py @@ -0,0 +1,197 @@ +import typing as t +from math import ceil + +from discord import Message + + +class MessageCache: +    """ +    A data structure for caching messages. + +    The cache is implemented as a circular buffer to allow constant time append, prepend, pop from either side, +    and lookup by index. The cache therefore does not support removal at an arbitrary index (although it can be +    implemented to work in linear time relative to the maximum size). + +    The object additionally holds a mapping from Discord message ID's to the index in which the corresponding message +    is stored, to allow for constant time lookup by message ID. + +    The cache has a size limit operating the same as with a collections.deque, and most of its method names mirror those +    of a deque. + +    The implementation is transparent to the user: to the user the first element is always at index 0, and there are +    only as many elements as were inserted (meaning, without any pre-allocated placeholder values). +    """ + +    def __init__(self, maxlen: int, *, newest_first: bool = False): +        if maxlen <= 0: +            raise ValueError("maxlen must be positive") +        self.maxlen = maxlen +        self.newest_first = newest_first + +        self._start = 0 +        self._end = 0 + +        self._messages: list[t.Optional[Message]] = [None] * self.maxlen +        self._message_id_mapping = {} + +    def append(self, message: Message) -> None: +        """Add the received message to the cache, depending on the order of messages defined by `newest_first`.""" +        if self.newest_first: +            self._appendleft(message) +        else: +            self._appendright(message) + +    def _appendright(self, message: Message) -> None: +        """Add the received message to the end of the cache.""" +        if self._is_full(): +            del self._message_id_mapping[self._messages[self._start].id] +            self._start = (self._start + 1) % self.maxlen + +        self._messages[self._end] = message +        self._message_id_mapping[message.id] = self._end +        self._end = (self._end + 1) % self.maxlen + +    def _appendleft(self, message: Message) -> None: +        """Add the received message to the beginning of the cache.""" +        if self._is_full(): +            self._end = (self._end - 1) % self.maxlen +            del self._message_id_mapping[self._messages[self._end].id] + +        self._start = (self._start - 1) % self.maxlen +        self._messages[self._start] = message +        self._message_id_mapping[message.id] = self._start + +    def pop(self) -> Message: +        """Remove the last message in the cache and return it.""" +        if self._is_empty(): +            raise IndexError("pop from an empty cache") + +        self._end = (self._end - 1) % self.maxlen +        message = self._messages[self._end] +        del self._message_id_mapping[message.id] +        self._messages[self._end] = None + +        return message + +    def popleft(self) -> Message: +        """Return the first message in the cache and return it.""" +        if self._is_empty(): +            raise IndexError("pop from an empty cache") + +        message = self._messages[self._start] +        del self._message_id_mapping[message.id] +        self._messages[self._start] = None +        self._start = (self._start + 1) % self.maxlen + +        return message + +    def clear(self) -> None: +        """Remove all messages from the cache.""" +        self._messages = [None] * self.maxlen +        self._message_id_mapping = {} + +        self._start = 0 +        self._end = 0 + +    def get_message(self, message_id: int) -> t.Optional[Message]: +        """Return the message that has the given message ID, if it is cached.""" +        index = self._message_id_mapping.get(message_id, None) +        return self._messages[index] if index is not None else None + +    def update(self, message: Message) -> bool: +        """ +        Update a cached message with new contents. + +        Return True if the given message had a matching ID in the cache. +        """ +        index = self._message_id_mapping.get(message.id, None) +        if index is None: +            return False +        self._messages[index] = message +        return True + +    def __contains__(self, message_id: int) -> bool: +        """Return True if the cache contains a message with the given ID .""" +        return message_id in self._message_id_mapping + +    def __getitem__(self, item: t.Union[int, slice]) -> t.Union[Message, list[Message]]: +        """ +        Return the message(s) in the index or slice provided. + +        This method makes the circular buffer implementation transparent to the user. +        Providing 0 will return the message at the position perceived by the user to be the beginning of the cache, +        meaning at `self._start`. +        """ +        # Keep in mind that for the modulo operator used throughout this function, Python modulo behaves similarly when +        # the left operand is negative. E.g -1 % 5 == 4, because the closest number from the bottom that wholly divides +        # by 5 is -5. +        if isinstance(item, int): +            if item >= len(self) or item < -len(self): +                raise IndexError("cache index out of range") +            return self._messages[(item + self._start) % self.maxlen] + +        elif isinstance(item, slice): +            length = len(self) +            start, stop, step = item.indices(length) + +            # This needs to be checked explicitly now, because otherwise self._start >= self._end is a valid state. +            if (start >= stop and step >= 0) or (start <= stop and step <= 0): +                return [] + +            start = (start + self._start) % self.maxlen +            stop = (stop + self._start) % self.maxlen + +            # Having empty cells is an implementation detail. To the user the cache contains as many elements as they +            # inserted, therefore any empty cells should be ignored. There can only be Nones at the tail. +            if step > 0: +                if ( +                    (self._start < self._end and not self._start < stop <= self._end) +                    or (self._start > self._end and self._end < stop <= self._start) +                ): +                    stop = self._end +            else: +                lower_boundary = (self._start - 1) % self.maxlen +                if ( +                    (self._start < self._end and not self._start - 1 <= stop < self._end) +                    or (self._start > self._end and self._end < stop < lower_boundary) +                ): +                    stop = lower_boundary + +            if (start < stop and step > 0) or (start > stop and step < 0): +                return self._messages[start:stop:step] +            # step != 1 may require a start offset in the second slicing. +            if step > 0: +                offset = ceil((self.maxlen - start) / step) * step + start - self.maxlen +                return self._messages[start::step] + self._messages[offset:stop:step] +            else: +                offset = ceil((start + 1) / -step) * -step - start - 1 +                return self._messages[start::step] + self._messages[self.maxlen - 1 - offset:stop:step] + +        else: +            raise TypeError(f"cache indices must be integers or slices, not {type(item)}") + +    def __iter__(self) -> t.Iterator[Message]: +        if self._is_empty(): +            return + +        if self._start < self._end: +            yield from self._messages[self._start:self._end] +        else: +            yield from self._messages[self._start:] +            yield from self._messages[:self._end] + +    def __len__(self): +        """Get the number of non-empty cells in the cache.""" +        if self._is_empty(): +            return 0 +        if self._end > self._start: +            return self._end - self._start +        return self.maxlen - self._start + self._end + +    def _is_empty(self) -> bool: +        """Return True if the cache has no messages.""" +        return self._messages[self._start] is None + +    def _is_full(self) -> bool: +        """Return True if every cell in the cache already contains a message.""" +        return self._messages[self._end] is not None diff --git a/config-default.yml b/config-default.yml index 881a7df76..eaf8e0ad7 100644 --- a/config-default.yml +++ b/config-default.yml @@ -377,6 +377,8 @@ urls:  anti_spam: +    cache_size: 100 +      # Clean messages that violate a rule.      clean_offending: true      ping_everyone: true @@ -432,14 +434,12 @@ anti_spam:              max: 3 -  metabase: -    username: !ENV "METABASE_USERNAME" -    password: !ENV "METABASE_PASSWORD" -    url: "http://metabase.default.svc.cluster.local/api" +    username: !ENV      "METABASE_USERNAME" +    password: !ENV      "METABASE_PASSWORD" +    base_url:           "http://metabase.default.svc.cluster.local"      # 14 days, see https://www.metabase.com/docs/latest/operations-guide/environment-variables.html#max_session_age -    max_session_age: 20160 - +    max_session_age:    20160  big_brother: diff --git a/tests/bot/exts/events/__init__.py b/tests/bot/exts/events/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/events/__init__.py diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/events/test_code_jams.py index 368a15476..b9ee1e363 100644 --- a/tests/bot/exts/utils/test_jams.py +++ b/tests/bot/exts/events/test_code_jams.py @@ -1,14 +1,15 @@  import unittest -from unittest.mock import AsyncMock, MagicMock, create_autospec +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch  from discord import CategoryChannel  from discord.ext.commands import BadArgument  from bot.constants import Roles -from bot.exts.utils import jams +from bot.exts.events import code_jams +from bot.exts.events.code_jams import _channels, _cog  from tests.helpers import (      MockAttachment, MockBot, MockCategoryChannel, MockContext, -    MockGuild, MockMember, MockRole, MockTextChannel +    MockGuild, MockMember, MockRole, MockTextChannel, autospec  )  TEST_CSV = b"""\ @@ -40,7 +41,7 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):          self.command_user = MockMember([self.admin_role])          self.guild = MockGuild([self.admin_role])          self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) -        self.cog = jams.CodeJams(self.bot) +        self.cog = _cog.CodeJams(self.bot)      async def test_message_without_attachments(self):          """If no link or attachments are provided, commands.BadArgument should be raised.""" @@ -49,7 +50,9 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):          with self.assertRaises(BadArgument):              await self.cog.create(self.cog, self.ctx, None) -    async def test_result_sending(self): +    @patch.object(_channels, "create_team_channel") +    @patch.object(_channels, "create_team_leader_channel") +    async def test_result_sending(self, create_leader_channel, create_team_channel):          """Should call `ctx.send` when everything goes right."""          self.ctx.message.attachments = [MockAttachment()]          self.ctx.message.attachments[0].read = AsyncMock() @@ -61,14 +64,12 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):          self.ctx.guild.create_role = AsyncMock()          self.ctx.guild.create_role.return_value = team_leaders -        self.cog.create_team_channel = AsyncMock() -        self.cog.create_team_leader_channel = AsyncMock()          self.cog.add_roles = AsyncMock()          await self.cog.create(self.cog, self.ctx, None) -        self.cog.create_team_channel.assert_awaited() -        self.cog.create_team_leader_channel.assert_awaited_once_with( +        create_team_channel.assert_awaited() +        create_leader_channel.assert_awaited_once_with(              self.ctx.guild, team_leaders          )          self.ctx.send.assert_awaited_once() @@ -81,25 +82,24 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):          self.ctx.send.assert_awaited_once() -    async def test_category_doesnt_exist(self): +    @patch.object(_channels, "_send_status_update") +    async def test_category_doesnt_exist(self, update):          """Should create a new code jam category."""          subtests = (              [], -            [get_mock_category(jams.MAX_CHANNELS, jams.CATEGORY_NAME)], -            [get_mock_category(jams.MAX_CHANNELS - 2, "other")], +            [get_mock_category(_channels.MAX_CHANNELS, _channels.CATEGORY_NAME)], +            [get_mock_category(_channels.MAX_CHANNELS - 2, "other")],          ) -        self.cog.send_status_update = AsyncMock() -          for categories in subtests: -            self.cog.send_status_update.reset_mock() +            update.reset_mock()              self.guild.reset_mock()              self.guild.categories = categories              with self.subTest(categories=categories): -                actual_category = await self.cog.get_category(self.guild) +                actual_category = await _channels._get_category(self.guild) -                self.cog.send_status_update.assert_called_once() +                update.assert_called_once()                  self.guild.create_category_channel.assert_awaited_once()                  category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] @@ -109,45 +109,41 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):      async def test_category_channel_exist(self):          """Should not try to create category channel.""" -        expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) +        expected_category = get_mock_category(_channels.MAX_CHANNELS - 2, _channels.CATEGORY_NAME)          self.guild.categories = [ -            get_mock_category(jams.MAX_CHANNELS - 2, "other"), +            get_mock_category(_channels.MAX_CHANNELS - 2, "other"),              expected_category, -            get_mock_category(0, jams.CATEGORY_NAME), +            get_mock_category(0, _channels.CATEGORY_NAME),          ] -        actual_category = await self.cog.get_category(self.guild) +        actual_category = await _channels._get_category(self.guild)          self.assertEqual(expected_category, actual_category)      async def test_channel_overwrites(self):          """Should have correct permission overwrites for users and roles."""          leader = (MockMember(), True)          members = [leader] + [(MockMember(), False) for _ in range(4)] -        overwrites = self.cog.get_overwrites(members, self.guild) +        overwrites = _channels._get_overwrites(members, self.guild)          for member, _ in members:              self.assertTrue(overwrites[member].read_messages) -    async def test_team_channels_creation(self): +    @patch.object(_channels, "_get_overwrites") +    @patch.object(_channels, "_get_category") +    @autospec(_channels, "_add_team_leader_roles", pass_mocks=False) +    async def test_team_channels_creation(self, get_category, get_overwrites):          """Should create a text channel for a team."""          team_leaders = MockRole()          members = [(MockMember(), True)] + [(MockMember(), False) for _ in range(5)]          category = MockCategoryChannel()          category.create_text_channel = AsyncMock() -        self.cog.get_overwrites = MagicMock() -        self.cog.get_category = AsyncMock() -        self.cog.get_category.return_value = category -        self.cog.add_team_leader_roles = AsyncMock() - -        await self.cog.create_team_channel(self.guild, "my-team", members, team_leaders) -        self.cog.add_team_leader_roles.assert_awaited_once_with(members, team_leaders) -        self.cog.get_overwrites.assert_called_once_with(members, self.guild) -        self.cog.get_category.assert_awaited_once_with(self.guild) +        get_category.return_value = category +        await _channels.create_team_channel(self.guild, "my-team", members, team_leaders)          category.create_text_channel.assert_awaited_once_with(              "my-team", -            overwrites=self.cog.get_overwrites.return_value +            overwrites=get_overwrites.return_value          )      async def test_jam_roles_adding(self): @@ -156,7 +152,7 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):          leader = MockMember()          members = [(leader, True)] + [(MockMember(), False) for _ in range(4)] -        await self.cog.add_team_leader_roles(members, leader_role) +        await _channels._add_team_leader_roles(members, leader_role)          leader.add_roles.assert_awaited_once_with(leader_role)          for member, is_leader in members: @@ -170,5 +166,5 @@ class CodeJamSetup(unittest.TestCase):      def test_setup(self):          """Should call `bot.add_cog`."""          bot = MockBot() -        jams.setup(bot) +        code_jams.setup(bot)          bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b9d527770..f844a9181 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -195,7 +195,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):      async def test_voice_unban_user_not_found(self):          """Should include info to return dict when user was not found from guild."""          self.guild.get_member.return_value = None -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {"Info": "User was not found in the guild."})      @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @@ -206,7 +206,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):          notify_pardon_mock.return_value = True          format_user_mock.return_value = "my-user" -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {              "Member": "my-user",              "DM": "Sent" @@ -221,7 +221,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):          notify_pardon_mock.return_value = False          format_user_mock.return_value = "my-user" -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {              "Member": "my-user",              "DM": "**Failed**" diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5f95ced9f..eb256f1fd 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -94,8 +94,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):          test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"])          test_cases = [              test_case([], None, None, True), -            test_case([{"id": 123987}], {"id": 123987}, "123987", False), -            test_case([{"id": 123987}], {"id": 123987}, "123987", True) +            test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False), +            test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True)          ]          for case in test_cases: diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..04bfd28d1 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,214 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): +    """Tests for the MessageCache class in the `bot.utils.caching` module.""" + +    def test_first_append_sets_the_first_value(self): +        """Test if the first append adds the message to the first cell.""" +        cache = MessageCache(maxlen=10) +        message = MockMessage() + +        cache.append(message) + +        self.assertEqual(cache[0], message) + +    def test_append_adds_in_the_right_order(self): +        """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" +        messages = [MockMessage(), MockMessage()] + +        cache = MessageCache(maxlen=10, newest_first=False) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages, list(cache)) + +        cache = MessageCache(maxlen=10, newest_first=True) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages[::-1], list(cache)) + +    def test_appending_over_maxlen_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages.""" +        cache = MessageCache(maxlen=2) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[1:], list(cache)) + +    def test_appending_over_maxlen_with_newest_first_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" +        cache = MessageCache(maxlen=2, newest_first=True) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[:0:-1], list(cache)) + +    def test_pop_removes_from_the_end(self): +        """Test if a pop removes the right-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.pop() + +        self.assertEqual(msg, messages[-1]) +        self.assertListEqual(messages[:-1], list(cache)) + +    def test_popleft_removes_from_the_beginning(self): +        """Test if a popleft removes the left-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.popleft() + +        self.assertEqual(msg, messages[0]) +        self.assertListEqual(messages[1:], list(cache)) + +    def test_clear(self): +        """Test if a clear makes the cache empty.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        cache.clear() + +        self.assertListEqual(list(cache), []) +        self.assertEqual(len(cache), 0) + +    def test_get_message_returns_the_message(self): +        """Test if get_message returns the cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertEqual(cache.get_message(1234), message) + +    def test_get_message_returns_none(self): +        """Test if get_message returns None for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIsNone(cache.get_message(4321)) + +    def test_update_replaces_old_element(self): +        """Test if an update replaced the old message with the same ID.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) +        message = MockMessage(id=1234) +        cache.update(message) + +        self.assertIs(cache.get_message(1234), message) +        self.assertEqual(len(cache), 1) + +    def test_contains_returns_true_for_cached_message(self): +        """Test if contains returns True for an ID of a cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIn(1234, cache) + +    def test_contains_returns_false_for_non_cached_message(self): +        """Test if contains returns False for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertNotIn(4321, cache) + +    def test_indexing(self): +        """Test if the cache returns the correct messages by index.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(5)] + +        for msg in messages: +            cache.append(msg) + +        for current_loop in range(-5, 5): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(cache[current_loop], messages[current_loop]) + +    def test_bad_index_raises_index_error(self): +        """Test if the cache raises IndexError for invalid indices.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] +        test_cases = (-10, -4, 3, 4, 5) + +        for msg in messages: +            cache.append(msg) + +        for current_loop in test_cases: +            with self.subTest(current_loop=current_loop): +                with self.assertRaises(IndexError): +                    cache[current_loop] + +    def test_slicing_with_unfilled_cache(self): +        """Test if slicing returns the correct messages if the cache is not yet fully filled.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size // 3 * 2)] + +            for msg in messages: +                cache.append(msg) + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_slicing_with_overfilled_cache(self): +        """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size * 3 // 2)] + +            for msg in messages: +                cache.append(msg) +            messages = messages[size // 2:] + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_length(self): +        """Test if len returns the correct number of items in the cache.""" +        cache = MessageCache(maxlen=5) + +        for current_loop in range(10): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(len(cache), min(current_loop, 5)) +                cache.append(MockMessage()) | 
