diff options
| -rw-r--r-- | .github/workflows/lint-test.yml | 62 | ||||
| -rw-r--r-- | Dockerfile | 38 | ||||
| -rw-r--r-- | bot/__main__.py | 1 | ||||
| -rw-r--r-- | bot/converters.py | 1 | ||||
| -rw-r--r-- | bot/exts/filters/filtering.py | 2 | ||||
| -rw-r--r-- | bot/exts/moderation/incidents.py | 23 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 15 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_utils.py | 32 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 49 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/management.py | 24 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 4 | ||||
| -rw-r--r-- | bot/exts/moderation/modlog.py | 2 | ||||
| -rw-r--r-- | bot/rules/mentions.py | 56 | ||||
| -rw-r--r-- | bot/utils/time.py | 44 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 11 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 29 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_incidents.py | 49 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 79 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 58 | ||||
| -rw-r--r-- | tests/helpers.py | 24 | 
20 files changed, 392 insertions, 211 deletions
| diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 2b3dd5b4f..a331659e6 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -33,57 +33,16 @@ jobs:        REDDIT_SECRET: ham        REDIS_PASSWORD: '' -      # Configure pip to cache dependencies and do a user install -      PIP_NO_CACHE_DIR: false -      PIP_USER: 1 - -      # Make sure package manager does not use virtualenv -      POETRY_VIRTUALENVS_CREATE: false - -      # Specify explicit paths for python dependencies and the pre-commit -      # environment so we know which directories to cache -      POETRY_CACHE_DIR: ${{ github.workspace }}/.cache/py-user-base -      PYTHONUSERBASE: ${{ github.workspace }}/.cache/py-user-base -      PRE_COMMIT_HOME: ${{ github.workspace }}/.cache/pre-commit-cache - -      # See https://github.com/pre-commit/pre-commit/issues/2178#issuecomment-1002163763 -      # for why we set this. -      SETUPTOOLS_USE_DISTUTILS: stdlib -      steps: -      - name: Add custom PYTHONUSERBASE to PATH -        run: echo '${{ env.PYTHONUSERBASE }}/bin/' >> $GITHUB_PATH -        - name: Checkout repository          uses: actions/checkout@v2 -      - name: Setup python -        id: python -        uses: actions/setup-python@v2 -        with: -          python-version: '3.10' - -      # This step caches our Python dependencies. To make sure we -      # only restore a cache when the dependencies, the python version, -      # the runner operating system, and the dependency location haven't -      # changed, we create a cache key that is a composite of those states. -      # -      # Only when the context is exactly the same, we will restore the cache. -      - name: Python Dependency Caching -        uses: actions/cache@v2 -        id: python_cache +      - name: Install Python Dependencies +        uses: HassanAbouelela/actions/setup-python@setup-python_v1.3.1          with: -          path: ${{ env.PYTHONUSERBASE }} -          key: "python-0-${{ runner.os }}-${{ env.PYTHONUSERBASE }}-\ -          ${{ steps.python.outputs.python-version }}-\ -          ${{ hashFiles('./pyproject.toml', './poetry.lock') }}" - -      # Install our dependencies if we did not restore a dependency cache -      - name: Install dependencies using poetry -        if: steps.python_cache.outputs.cache-hit != 'true' -        run: | -          pip install poetry -          poetry install +          # Set dev=true to install flake8 extensions, which are dev dependencies +          dev: true +          python_version: '3.10'        # Check all of our non-dev dependencies are compatible with the MIT license.        # If you added a new dependencies that is being rejected, @@ -94,17 +53,6 @@ jobs:            pip-licenses --allow-only="$ALLOWED_LICENSE" \              --package $(poetry export -f requirements.txt --without-hashes | sed "s/==.*//g" | tr "\n" " ") -      # This step caches our pre-commit environment. To make sure we -      # do create a new environment when our pre-commit setup changes, -      # we create a cache key based on relevant factors. -      - name: Pre-commit Environment Caching -        uses: actions/cache@v2 -        with: -          path: ${{ env.PRE_COMMIT_HOME }} -          key: "precommit-0-${{ runner.os }}-${{ env.PRE_COMMIT_HOME }}-\ -          ${{ steps.python.outputs.python-version }}-\ -          ${{ hashFiles('./.pre-commit-config.yaml') }}" -        # We will not run `flake8` here, as we will use a separate flake8        # action. As pre-commit does not support user installs, we set        # PIP_USER=0 to not do a user install. diff --git a/Dockerfile b/Dockerfile index 5bb400658..65ca8ce51 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,28 +1,34 @@  FROM --platform=linux/amd64 python:3.10-slim -# Set pip to have no saved cache -ENV PIP_NO_CACHE_DIR=false \ -    POETRY_VIRTUALENVS_CREATE=false +# Define Git SHA build argument for sentry +ARG git_sha="development" + +# POETRY_VIRTUALENVS_IN_PROJECT is required to ensure in-projects venvs mounted from the host in dev +# don't get prioritised by `poetry run` +ENV POETRY_VERSION=1.2.0 \ +  POETRY_HOME="/opt/poetry/home" \ +  POETRY_CACHE_DIR="/opt/poetry/cache" \ +  POETRY_NO_INTERACTION=1 \ +  POETRY_VIRTUALENVS_IN_PROJECT=false \ +  APP_DIR="/bot" \ +  GIT_SHA=$git_sha +ENV PATH="$POETRY_HOME/bin:$PATH" -# Install poetry -RUN pip install -U poetry +RUN apt-get update \ +  && apt-get -y upgrade \ +  && apt-get install --no-install-recommends -y curl \ +  && apt-get clean && rm -rf /var/lib/apt/lists/* -# Create the working directory -WORKDIR /bot +RUN curl -sSL https://install.python-poetry.org | python  # Install project dependencies +WORKDIR $APP_DIR  COPY pyproject.toml poetry.lock ./ -RUN poetry install --no-dev - -# Define Git SHA build argument -ARG git_sha="development" - -# Set Git SHA environment variable for Sentry -ENV GIT_SHA=$git_sha +RUN poetry install --without dev  # Copy the source code in last to optimize rebuilding the image  COPY . . -ENTRYPOINT ["python3"] -CMD ["-m", "bot"] +ENTRYPOINT ["poetry"] +CMD ["run", "python", "-m", "bot"] diff --git a/bot/__main__.py b/bot/__main__.py index e0d2e6ad5..02af2e9ef 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -26,6 +26,7 @@ async def _create_redis_session() -> RedisSession:          max_connections=20,          use_fakeredis=constants.Redis.use_fakeredis,          global_namespace="bot", +        decode_responses=True,      )      try:          return await redis_session.connect() diff --git a/bot/converters.py b/bot/converters.py index 5800ea044..e97a25bdd 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -574,5 +574,6 @@ if t.TYPE_CHECKING:      Infraction = t.Optional[dict]  # noqa: F811  Expiry = t.Union[Duration, ISODateTime] +DurationOrExpiry = t.Union[DurationDelta, ISODateTime]  MemberOrUser = t.Union[discord.Member, discord.User]  UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index ca6ad0064..e4df0b1fd 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -413,7 +413,7 @@ class Filtering(Cog):                              await context.invoke(                                  context.command,                                  msg.author, -                                arrow.utcnow() + AUTO_BAN_DURATION, +                                (arrow.utcnow() + AUTO_BAN_DURATION).datetime,                                  reason=AUTO_BAN_REASON                              ) diff --git a/bot/exts/moderation/incidents.py b/bot/exts/moderation/incidents.py index 155b123ca..1ddbe9857 100644 --- a/bot/exts/moderation/incidents.py +++ b/bot/exts/moderation/incidents.py @@ -1,6 +1,6 @@  import asyncio  import re -from datetime import datetime +from datetime import datetime, timezone  from enum import Enum  from typing import Optional @@ -13,6 +13,7 @@ from bot.bot import Bot  from bot.constants import Channels, Colours, Emojis, Guild, Roles, Webhooks  from bot.log import get_logger  from bot.utils.messages import format_user, sub_clyde +from bot.utils.time import TimestampFormats, discord_timestamp  log = get_logger(__name__) @@ -25,9 +26,9 @@ CRAWL_LIMIT = 50  CRAWL_SLEEP = 2  DISCORD_MESSAGE_LINK_RE = re.compile( -    r"(https?:\/\/(?:(ptb|canary|www)\.)?discord(?:app)?\.com\/channels\/" +    r"(https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"      r"[0-9]{15,20}" -    r"\/[0-9]{15,20}\/[0-9]{15,20})" +    r"/[0-9]{15,20}/[0-9]{15,20})"  ) @@ -97,10 +98,20 @@ async def make_embed(incident: discord.Message, outcome: Signal, actioned_by: di          colour = Colours.soft_red          footer = f"Rejected by {actioned_by}" +    reported_timestamp = discord_timestamp(incident.created_at) +    relative_timestamp = discord_timestamp(incident.created_at, TimestampFormats.RELATIVE) +    reported_on_msg = f"*Reported {reported_timestamp} ({relative_timestamp}).*" + +    # If the description will be too long (>4096 total characters), truncate the incident content +    if len(incident.content) > (allowed_content_chars := 4096-len(reported_on_msg)-2):  # -2 for the newlines +        description = incident.content[:allowed_content_chars-3] + f"...\n\n{reported_on_msg}" +    else: +        description = incident.content + f"\n\n{reported_on_msg}" +      embed = discord.Embed( -        description=incident.content, -        timestamp=datetime.utcnow(), +        description=description,          colour=colour, +        timestamp=datetime.now(timezone.utc)      )      embed.set_footer(text=footer, icon_url=actioned_by.display_avatar.url) @@ -381,7 +392,7 @@ class Incidents(Cog):              webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive)              await webhook.send(                  embed=embed, -                username=sub_clyde(incident.author.name), +                username=sub_clyde(incident.author.display_name),                  avatar_url=incident.author.display_avatar.url,                  file=attachment_file,              ) diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index c7f03b2e9..655290559 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -12,7 +12,7 @@ from discord.ext.commands import Context  from bot import constants  from bot.bot import Bot -from bot.constants import Colours +from bot.constants import Colours, Roles  from bot.converters import MemberOrUser  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.modlog import ModLog @@ -137,8 +137,11 @@ class InfractionScheduler:          infr_type = infraction["type"]          icon = _utils.INFRACTION_ICONS[infr_type][0]          reason = infraction["reason"] -        expiry = time.format_with_duration(infraction["expires_at"])          id_ = infraction['id'] +        expiry = time.format_with_duration( +            infraction["expires_at"], +            infraction["last_applied"] +        )          if user_reason is None:              user_reason = reason @@ -189,7 +192,10 @@ class InfractionScheduler:                  f"Infraction #{id_} actor is bot; including the reason in the confirmation message."              )              if reason: -                end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" +                end_msg = ( +                    f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})." +                    f"\n\nThe <@&{Roles.moderators}> have been alerted for review" +                )          purge = infraction.get("purge", "") @@ -243,7 +249,8 @@ class InfractionScheduler:          # Send a confirmation message to the invoking context.          log.trace(f"Sending infraction #{id_} confirmation message.") -        await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.") +        mentions = discord.AllowedMentions(users=[user], roles=False) +        await ctx.send(f"{dm_result}{confirm_msg}{infr_message}.", allowed_mentions=mentions)          # Send a log message to the mod log.          # Don't use ctx.message.author for the actor; antispam only patches ctx.author. diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 3a2485ec2..c03081b07 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -1,5 +1,4 @@  import typing as t -from datetime import datetime  import arrow  import discord @@ -8,10 +7,11 @@ from discord.ext.commands import Context  import bot  from bot.constants import Colours, Icons -from bot.converters import MemberOrUser +from bot.converters import DurationOrExpiry, MemberOrUser  from bot.errors import InvalidInfractedUserError  from bot.log import get_logger  from bot.utils import time +from bot.utils.time import unpack_duration  log = get_logger(__name__) @@ -44,8 +44,8 @@ LONGEST_EXTRAS = max(len(INFRACTION_APPEAL_SERVER_FOOTER), len(INFRACTION_APPEAL  INFRACTION_DESCRIPTION_TEMPLATE = (      "**Type:** {type}\n" -    "**Expires:** {expires}\n"      "**Duration:** {duration}\n" +    "**Expires:** {expires}\n"      "**Reason:** {reason}\n"  ) @@ -80,7 +80,7 @@ async def post_infraction(          user: MemberOrUser,          infr_type: str,          reason: str, -        expires_at: datetime = None, +        duration_or_expiry: t.Optional[DurationOrExpiry] = None,          hidden: bool = False,          active: bool = True,          dm_sent: bool = False, @@ -92,6 +92,8 @@ async def post_infraction(      log.trace(f"Posting {infr_type} infraction for {user} to the API.") +    current_time = arrow.utcnow() +      payload = {          "actor": ctx.author.id,  # Don't use ctx.message.author; antispam only patches ctx.author.          "hidden": hidden, @@ -99,10 +101,14 @@ async def post_infraction(          "type": infr_type,          "user": user.id,          "active": active, -        "dm_sent": dm_sent +        "dm_sent": dm_sent, +        "inserted_at": current_time.isoformat(), +        "last_applied": current_time.isoformat(),      } -    if expires_at: -        payload['expires_at'] = expires_at.isoformat() + +    if duration_or_expiry is not None: +        _, expiry = unpack_duration(duration_or_expiry, current_time) +        payload["expires_at"] = expiry.isoformat()      # Try to apply the infraction. If it fails because the user doesn't exist, try to add it.      for should_post_user in (True, False): @@ -180,17 +186,17 @@ async def notify_infraction(          expires_at = "Never"          duration = "Permanent"      else: +        origin = arrow.get(infraction["last_applied"])          expiry = arrow.get(infraction["expires_at"])          expires_at = time.format_relative(expiry) -        duration = time.humanize_delta(infraction["inserted_at"], expiry, max_units=2) +        duration = time.humanize_delta(origin, expiry, max_units=2) -        if infraction["active"]: -            remaining = time.humanize_delta(expiry, arrow.utcnow(), max_units=2) -            if duration != remaining: -                duration += f" ({remaining} remaining)" -        else: +        if not infraction["active"]:              expires_at += " (Inactive)" +        if infraction["inserted_at"] != infraction["last_applied"]: +            duration += " (Edited)" +      log.trace(f"Sending {user} a DM about their {infr_type} infraction.")      if reason is None: diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 46fd3381c..05cc74a03 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -1,6 +1,7 @@  import textwrap  import typing as t +import arrow  import discord  from discord import Member  from discord.ext import commands @@ -9,8 +10,9 @@ from discord.ext.commands import Context, command  from bot import constants  from bot.bot import Bot  from bot.constants import Event -from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser +from bot.converters import Age, Duration, DurationOrExpiry, MemberOrUser, UnambiguousMemberOrUser  from bot.decorators import ensure_future_timestamp, respect_role_hierarchy +from bot.exts.filters.filtering import AUTO_BAN_DURATION, AUTO_BAN_REASON  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.log import get_logger @@ -86,16 +88,18 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration_or_expiry: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str] = None      ) -> None:          """ -        Permanently ban a user for the given reason and stop watching them with Big Brother. +        Permanently ban a `user` for the given `reason` and stop watching them with Big Brother. -        If duration is specified, it temporarily bans that user for the given duration. +        If a duration is specified, it temporarily bans the `user` for the given duration. +        Alternatively, an ISO 8601 timestamp representing the expiry time can be provided +        for `duration_or_expiry`.          """ -        await self.apply_ban(ctx, user, reason, expires_at=duration) +        await self.apply_ban(ctx, user, reason, duration_or_expiry=duration_or_expiry)      @command(aliases=("cban", "purgeban", "pban"))      @ensure_future_timestamp(timestamp_arg=3) @@ -103,7 +107,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str] = None      ) -> None: @@ -115,10 +119,10 @@ class Infractions(InfractionScheduler, commands.Cog):          clean_cog: t.Optional[Clean] = self.bot.get_cog("Clean")          if clean_cog is None:              # If we can't get the clean cog, fall back to native purgeban. -            await self.apply_ban(ctx, user, reason, purge_days=1, expires_at=duration) +            await self.apply_ban(ctx, user, reason, purge_days=1, duration_or_expiry=duration)              return -        infraction = await self.apply_ban(ctx, user, reason, expires_at=duration) +        infraction = await self.apply_ban(ctx, user, reason, duration_or_expiry=duration)          if not infraction or not infraction.get("id"):              # Ban was unsuccessful, quit early.              await ctx.send(":x: Failed to apply ban.") @@ -151,6 +155,11 @@ class Infractions(InfractionScheduler, commands.Cog):          ctx.send = send          await infr_manage_cog.infraction_append(ctx, infraction, None, reason=f"[Clean log]({log_url})") +    @command() +    async def compban(self, ctx: Context, user: UnambiguousMemberOrUser) -> None: +        """Same as cleanban, but specifically with the ban reason and duration used for compromised accounts.""" +        await self.cleanban(ctx, user, duration=(arrow.utcnow() + AUTO_BAN_DURATION).datetime, reason=AUTO_BAN_REASON) +      @command(aliases=("vban",))      async def voiceban(self, ctx: Context) -> None:          """ @@ -168,7 +177,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str]      ) -> None: @@ -177,7 +186,7 @@ class Infractions(InfractionScheduler, commands.Cog):          If duration is specified, it temporarily voice mutes that user for the given duration.          """ -        await self.apply_voice_mute(ctx, user, reason, expires_at=duration) +        await self.apply_voice_mute(ctx, user, reason, duration_or_expiry=duration)      # endregion      # region: Temporary infractions @@ -187,7 +196,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def tempmute(          self, ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str] = None      ) -> None: @@ -214,7 +223,7 @@ class Infractions(InfractionScheduler, commands.Cog):          if duration is None:              duration = await Duration().convert(ctx, "1h") -        await self.apply_mute(ctx, user, reason, expires_at=duration) +        await self.apply_mute(ctx, user, reason, duration_or_expiry=duration)      @command(aliases=("tban",))      @ensure_future_timestamp(timestamp_arg=3) @@ -222,7 +231,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: Expiry, +        duration_or_expiry: DurationOrExpiry,          *,          reason: t.Optional[str] = None      ) -> None: @@ -241,7 +250,7 @@ class Infractions(InfractionScheduler, commands.Cog):          Alternatively, an ISO 8601 timestamp can be provided for the duration.          """ -        await self.apply_ban(ctx, user, reason, expires_at=duration) +        await self.apply_ban(ctx, user, reason, duration_or_expiry=duration_or_expiry)      @command(aliases=("tempvban", "tvban"))      async def tempvoiceban(self, ctx: Context) -> None: @@ -258,7 +267,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: Expiry, +        duration: DurationOrExpiry,          *,          reason: t.Optional[str]      ) -> None: @@ -277,7 +286,7 @@ class Infractions(InfractionScheduler, commands.Cog):          Alternatively, an ISO 8601 timestamp can be provided for the duration.          """ -        await self.apply_voice_mute(ctx, user, reason, expires_at=duration) +        await self.apply_voice_mute(ctx, user, reason, duration_or_expiry=duration)      # endregion      # region: Permanent shadow infractions @@ -305,7 +314,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: Expiry, +        duration: DurationOrExpiry,          *,          reason: t.Optional[str] = None      ) -> None: @@ -324,7 +333,7 @@ class Infractions(InfractionScheduler, commands.Cog):          Alternatively, an ISO 8601 timestamp can be provided for the duration.          """ -        await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) +        await self.apply_ban(ctx, user, reason, duration_or_expiry=duration, hidden=True)      # endregion      # region: Remove infractions (un- commands) @@ -428,7 +437,7 @@ class Infractions(InfractionScheduler, commands.Cog):              return None          # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active -        is_temporary = kwargs.get("expires_at") is not None +        is_temporary = kwargs.get("duration_or_expiry") is not None          active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary)          if active_infraction: @@ -436,7 +445,7 @@ class Infractions(InfractionScheduler, commands.Cog):                  log.trace("Tempban ignored as it cannot overwrite an active ban.")                  return None -            if active_infraction.get('expires_at') is None: +            if active_infraction.get("duration_or_expiry") is None:                  log.trace("Permaban already exists, notify.")                  await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).")                  return None diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index a7d7a844a..6ef382119 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -2,6 +2,7 @@ import re  import textwrap  import typing as t +import arrow  import discord  from discord.ext import commands  from discord.ext.commands import Context @@ -9,7 +10,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser +from bot.converters import DurationOrExpiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser  from bot.decorators import ensure_future_timestamp  from bot.errors import InvalidInfraction  from bot.exts.moderation.infraction import _utils @@ -20,6 +21,7 @@ from bot.pagination import LinePaginator  from bot.utils import messages, time  from bot.utils.channel import is_mod_channel  from bot.utils.members import get_or_fetch_member +from bot.utils.time import unpack_duration  log = get_logger(__name__) @@ -89,7 +91,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction: Infraction, -        duration: t.Union[Expiry, t.Literal["p", "permanent"], None], +        duration: t.Union[DurationOrExpiry, t.Literal["p", "permanent"], None],          *,          reason: str = None      ) -> None: @@ -129,7 +131,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction: Infraction, -        duration: t.Union[Expiry, t.Literal["p", "permanent"], None], +        duration: t.Union[DurationOrExpiry, t.Literal["p", "permanent"], None],          *,          reason: str = None      ) -> None: @@ -172,8 +174,11 @@ class ModManagement(commands.Cog):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: -            request_data['expires_at'] = duration.isoformat() -            expiry = time.format_with_duration(duration) +            origin, expiry = unpack_duration(duration) +            # Update `last_applied` if expiry changes. +            request_data['last_applied'] = origin.isoformat() +            request_data['expires_at'] = expiry.isoformat() +            expiry = time.format_with_duration(expiry, origin)              confirm_messages.append(f"set to expire on {expiry}")          else:              confirm_messages.append("expiry unchanged") @@ -380,7 +385,10 @@ class ModManagement(commands.Cog):          user = infraction["user"]          expires_at = infraction["expires_at"]          inserted_at = infraction["inserted_at"] +        last_applied = infraction["last_applied"]          created = time.discord_timestamp(inserted_at) +        applied = time.discord_timestamp(last_applied) +        duration_edited = arrow.get(last_applied) > arrow.get(inserted_at)          dm_sent = infraction["dm_sent"]          # Format the user string. @@ -400,7 +408,11 @@ class ModManagement(commands.Cog):          if expires_at is None:              duration = "*Permanent*"          else: -            duration = time.humanize_delta(inserted_at, expires_at) +            duration = time.humanize_delta(last_applied, expires_at) + +        # Notice if infraction expiry was edited. +        if duration_edited: +            duration += f" (edited {applied})"          # Format `dm_sent`          if dm_sent is None: diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 0e6aaa1e7..f2aab7a92 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -10,7 +10,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Duration, Expiry +from bot.converters import Duration, DurationOrExpiry  from bot.decorators import ensure_future_timestamp  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler @@ -109,7 +109,7 @@ class Superstarify(InfractionScheduler, Cog):          self,          ctx: Context,          member: Member, -        duration: t.Optional[Expiry], +        duration: t.Optional[DurationOrExpiry],          *,          reason: str = '',      ) -> None: diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index 67991730e..efa87ce25 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -552,7 +552,7 @@ class ModLog(Cog, name="ModLog"):          channel = self.bot.get_channel(channel_id)          # Ignore not found channels, DMs, and messages outside of the main guild. -        if not channel or not hasattr(channel, "guild") or channel.guild.id != GuildConstant.id: +        if not channel or channel.guild is None or channel.guild.id != GuildConstant.id:              return True          # Look at the parent channel of a thread. diff --git a/bot/rules/mentions.py b/bot/rules/mentions.py index 6f5addad1..ca1d0c01c 100644 --- a/bot/rules/mentions.py +++ b/bot/rules/mentions.py @@ -1,23 +1,65 @@  from typing import Dict, Iterable, List, Optional, Tuple -from discord import Member, Message +from discord import DeletedReferencedMessage, Member, Message, MessageType, NotFound + +import bot +from bot.log import get_logger + +log = get_logger(__name__)  async def apply(      last_message: Message, recent_messages: List[Message], config: Dict[str, int]  ) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: -    """Detects total mentions exceeding the limit sent by a single user.""" +    """ +    Detects total mentions exceeding the limit sent by a single user. + +    Excludes mentions that are bots, themselves, or replied users. + +    In very rare cases, may not be able to determine a +    mention was to a reply, in which case it is not ignored. +    """      relevant_messages = tuple(          msg          for msg in recent_messages          if msg.author == last_message.author      ) +    # We use `msg.mentions` here as that is supplied by the api itself, to determine who was mentioned. +    # Additionally, `msg.mentions` includes the user replied to, even if the mention doesn't occur in the body. +    # In order to exclude users who are mentioned as a reply, we check if the msg has a reference +    # +    # While we could use regex to parse the message content, and get a list of +    # the mentions, that solution is very prone to breaking. +    # We would need to deal with codeblocks, escaping markdown, and any discrepancies between +    # our implementation and discord's markdown parser which would cause false positives or false negatives. +    total_recent_mentions = 0 +    for msg in relevant_messages: +        # We check if the message is a reply, and if it is try to get the author +        # since we ignore mentions of a user that we're replying to +        reply_author = None -    total_recent_mentions = sum( -        not user.bot -        for msg in relevant_messages -        for user in msg.mentions -    ) +        if msg.type == MessageType.reply: +            ref = msg.reference + +            if not (resolved := ref.resolved): +                # It is possible, in a very unusual situation, for a message to have a reference +                # that is both not in the cache, and deleted while running this function. +                # In such a situation, this will throw an error which we catch. +                try: +                    resolved = await bot.instance.get_partial_messageable(resolved.channel_id).fetch_message( +                        resolved.message_id +                    ) +                except NotFound: +                    log.info('Could not fetch the reference message as it has been deleted.') + +            if resolved and not isinstance(resolved, DeletedReferencedMessage): +                reply_author = resolved.author + +        for user in msg.mentions: +            # Don't count bot or self mentions, or the user being replied to (if applicable) +            if user.bot or user in {msg.author, reply_author}: +                continue +            total_recent_mentions += 1      if total_recent_mentions > config['max']:          return ( diff --git a/bot/utils/time.py b/bot/utils/time.py index a0379c3ef..820ac2929 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,12 +1,18 @@ +from __future__ import annotations +  import datetime  import re +from copy import copy  from enum import Enum  from time import struct_time -from typing import Literal, Optional, Union, overload +from typing import Literal, Optional, TYPE_CHECKING, Union, overload  import arrow  from dateutil.relativedelta import relativedelta +if TYPE_CHECKING: +    from bot.converters import DurationOrExpiry +  _DURATION_REGEX = re.compile(      r"((?P<years>\d+?) ?(years|year|Y|y) ?)?"      r"((?P<months>\d+?) ?(months|month|m) ?)?" @@ -194,8 +200,8 @@ def humanize_delta(      elif len(args) <= 2:          end = arrow.get(args[0])          start = arrow.get(args[1]) if len(args) == 2 else arrow.utcnow() +        delta = round_delta(relativedelta(end.datetime, start.datetime)) -        delta = relativedelta(end.datetime, start.datetime)          if absolute:              delta = abs(delta)      else: @@ -326,3 +332,37 @@ def until_expiration(expiry: Optional[Timestamp]) -> str:          return "Expired"      return format_relative(expiry) + + +def unpack_duration( +        duration_or_expiry: DurationOrExpiry, +        origin: Optional[Union[datetime.datetime, arrow.Arrow]] = None +) -> tuple[datetime.datetime, datetime.datetime]: +    """ +    Unpacks a DurationOrExpiry into a tuple of (origin, expiry). + +    The `origin` defaults to the current UTC time at function call. +    """ +    if origin is None: +        origin = datetime.datetime.now(tz=datetime.timezone.utc) + +    if isinstance(origin, arrow.Arrow): +        origin = origin.datetime + +    if isinstance(duration_or_expiry, relativedelta): +        return origin, origin + duration_or_expiry +    else: +        return origin, duration_or_expiry + + +def round_delta(delta: relativedelta) -> relativedelta: +    """ +    Rounds `delta` to the nearest second. + +    Returns a copy with microsecond values of 0. +    """ +    delta = copy(delta) +    if delta.microseconds >= 500000: +        delta += relativedelta(seconds=1) +    delta.microseconds = 0 +    return delta diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 052048053..a18a4d23b 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -79,13 +79,13 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          """Should call voice mute applying function without expiry."""          self.cog.apply_voice_mute = AsyncMock()          self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) -        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) +        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry=None)      async def test_temporary_voice_mute(self):          """Should call voice mute applying function with expiry."""          self.cog.apply_voice_mute = AsyncMock()          self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) -        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") +        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry="baz")      async def test_voice_unmute(self):          """Should call infraction pardoning function.""" @@ -189,7 +189,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          user = MockUser()          await self.cog.voicemute(self.cog, self.ctx, user, reason=None) -        post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) +        post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, +                                                     duration_or_expiry=None)          apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY)          # Test action @@ -273,7 +274,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):              self.user,              "FooBar",              purge_days=1, -            expires_at=None, +            duration_or_expiry=None,          )      async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): @@ -285,7 +286,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):              self.ctx,              self.user,              "FooBar", -            expires_at=None, +            duration_or_expiry=None,          )      @patch("bot.exts.moderation.infraction.infractions.Age") diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5cf02033d..29dadf372 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -1,7 +1,7 @@  import unittest  from collections import namedtuple  from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, patch  from botcore.site_api import ResponseCodeError  from discord import Embed, Forbidden, HTTPException, NotFound @@ -309,8 +309,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):      async def test_normal_post_infraction(self):          """Should return response from POST request if there are no errors.""" -        now = datetime.now() -        payload = { +        now = datetime.utcnow() +        expected = {              "actor": self.ctx.author.id,              "hidden": True,              "reason": "Test reason", @@ -318,14 +318,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):              "user": self.member.id,              "active": False,              "expires_at": now.isoformat(), -            "dm_sent": False +            "dm_sent": False,          }          self.ctx.bot.api_client.post.return_value = "foo"          actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) -          self.assertEqual(actual, "foo") -        self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) +        self.ctx.bot.api_client.post.assert_awaited_once() + +        # Since `last_applied` is based on current time, just check if expected is a subset of payload +        payload: dict = self.ctx.bot.api_client.post.await_args_list[0].kwargs["json"] +        self.assertEqual(payload, payload | expected)      async def test_unknown_error_post_infraction(self):          """Should send an error message to chat when a non-400 error occurs.""" @@ -349,19 +352,25 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):      @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar")      async def test_first_fail_second_success_user_post_infraction(self, post_user_mock):          """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" -        payload = { +        expected = {              "actor": self.ctx.author.id,              "hidden": False,              "reason": "Test reason",              "type": "mute",              "user": self.user.id,              "active": True, -            "dm_sent": False +            "dm_sent": False,          }          self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] -          actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason")          self.assertEqual(actual, "foo") -        self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) +        await_args = self.bot.api_client.post.await_args_list +        self.assertEqual(len(await_args), 2, "Expected 2 awaits") + +        # Since `last_applied` is based on current time, just check if expected is a subset of payload +        for args in await_args: +            payload: dict = args.kwargs["json"] +            self.assertEqual(payload, payload | expected) +          post_user_mock.assert_awaited_once_with(self.ctx, self.user) diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py index 97682163f..53d98360c 100644 --- a/tests/bot/exts/moderation/test_incidents.py +++ b/tests/bot/exts/moderation/test_incidents.py @@ -1,4 +1,5 @@  import asyncio +import datetime  import enum  import logging  import typing as t @@ -12,12 +13,15 @@ import discord  from bot.constants import Colours  from bot.exts.moderation import incidents  from bot.utils.messages import format_user +from bot.utils.time import TimestampFormats, discord_timestamp  from tests.base import RedisTestCase  from tests.helpers import (      MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,      MockUser  ) +CURRENT_TIME = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc) +  class MockAsyncIterable:      """ @@ -100,30 +104,45 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase):      async def test_make_embed_actioned(self):          """Embed is coloured green and footer contains 'Actioned' when `outcome=Signal.ACTIONED`.""" -        embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.ACTIONED, MockMember()) +        embed, file = await incidents.make_embed( +            incident=MockMessage(created_at=CURRENT_TIME), +            outcome=incidents.Signal.ACTIONED, +            actioned_by=MockMember() +        )          self.assertEqual(embed.colour.value, Colours.soft_green)          self.assertIn("Actioned", embed.footer.text)      async def test_make_embed_not_actioned(self):          """Embed is coloured red and footer contains 'Rejected' when `outcome=Signal.NOT_ACTIONED`.""" -        embed, file = await incidents.make_embed(MockMessage(), incidents.Signal.NOT_ACTIONED, MockMember()) +        embed, file = await incidents.make_embed( +            incident=MockMessage(created_at=CURRENT_TIME), +            outcome=incidents.Signal.NOT_ACTIONED, +            actioned_by=MockMember() +        )          self.assertEqual(embed.colour.value, Colours.soft_red)          self.assertIn("Rejected", embed.footer.text)      async def test_make_embed_content(self):          """Incident content appears as embed description.""" -        incident = MockMessage(content="this is an incident") +        incident = MockMessage(content="this is an incident", created_at=CURRENT_TIME) + +        reported_timestamp = discord_timestamp(CURRENT_TIME) +        relative_timestamp = discord_timestamp(CURRENT_TIME, TimestampFormats.RELATIVE) +          embed, file = await incidents.make_embed(incident, incidents.Signal.ACTIONED, MockMember()) -        self.assertEqual(incident.content, embed.description) +        self.assertEqual( +            f"{incident.content}\n\n*Reported {reported_timestamp} ({relative_timestamp}).*", +            embed.description +        )      async def test_make_embed_with_attachment_succeeds(self):          """Incident's attachment is downloaded and displayed in the embed's image field."""          file = MagicMock(discord.File, filename="bigbadjoe.jpg")          attachment = MockAttachment(filename="bigbadjoe.jpg") -        incident = MockMessage(content="this is an incident", attachments=[attachment]) +        incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME)          # Patch `download_file` to return our `file`          with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=file)): @@ -135,7 +154,7 @@ class TestMakeEmbed(unittest.IsolatedAsyncioTestCase):      async def test_make_embed_with_attachment_fails(self):          """Incident's attachment fails to download, proxy url is linked instead."""          attachment = MockAttachment(proxy_url="discord.com/bigbadjoe.jpg") -        incident = MockMessage(content="this is an incident", attachments=[attachment]) +        incident = MockMessage(content="this is an incident", attachments=[attachment], created_at=CURRENT_TIME)          # Patch `download_file` to return None as if the download failed          with patch("bot.exts.moderation.incidents.download_file", AsyncMock(return_value=None)): @@ -349,7 +368,6 @@ class TestCrawlIncidents(TestIncidents):  class TestArchive(TestIncidents):      """Tests for the `Incidents.archive` coroutine.""" -      async def test_archive_webhook_not_found(self):          """          Method recovers and returns False when the webhook is not found. @@ -359,7 +377,11 @@ class TestArchive(TestIncidents):          """          self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404)          self.assertFalse( -            await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock(), actioned_by=MockMember()) +            await self.cog_instance.archive( +                incident=MockMessage(created_at=CURRENT_TIME), +                outcome=MagicMock(), +                actioned_by=MockMember() +            )          )      async def test_archive_relays_incident(self): @@ -375,7 +397,7 @@ class TestArchive(TestIncidents):          # Define our own `incident` to be archived          incident = MockMessage(              content="this is an incident", -            author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")), +            author=MockUser(display_name="author_name", display_avatar=Mock(url="author_avatar")),              id=123,          )          built_embed = MagicMock(discord.Embed, id=123)  # We patch `make_embed` to return this @@ -406,7 +428,7 @@ class TestArchive(TestIncidents):          webhook = MockAsyncWebhook()          self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) -        message_from_clyde = MockMessage(author=MockUser(name="clyde the great")) +        message_from_clyde = MockMessage(author=MockUser(display_name="clyde the great"), created_at=CURRENT_TIME)          await self.cog_instance.archive(message_from_clyde, MagicMock(incidents.Signal), MockMember())          self.assertNotIn("clyde", webhook.send.call_args.kwargs["username"]) @@ -505,12 +527,13 @@ class TestProcessEvent(TestIncidents):      async def test_process_event_confirmation_task_is_awaited(self):          """Task given by `Incidents.make_confirmation_task` is awaited before method exits."""          mock_task = AsyncMock() +        mock_member = MockMember(display_name="Bobby Johnson", roles=[MockRole(id=1)])          with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):              await self.cog_instance.process_event(                  reaction=incidents.Signal.ACTIONED.value, -                incident=MockMessage(id=123), -                member=MockMember(roles=[MockRole(id=1)]) +                incident=MockMessage(author=mock_member, id=123, created_at=CURRENT_TIME), +                member=mock_member              )          mock_task.assert_awaited() @@ -529,7 +552,7 @@ class TestProcessEvent(TestIncidents):              with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):                  await self.cog_instance.process_event(                      reaction=incidents.Signal.ACTIONED.value, -                    incident=MockMessage(id=123), +                    incident=MockMessage(id=123, created_at=CURRENT_TIME),                      member=MockMember(roles=[MockRole(id=1)])                  )          except asyncio.TimeoutError: diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 98547e2bc..2622f46a7 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,4 +1,3 @@ -import asyncio  import itertools  import unittest  from datetime import datetime, timezone @@ -23,8 +22,24 @@ class PatchedDatetime(datetime):      now = mock.create_autospec(datetime, "now") -class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase): +class SilenceTest(RedisTestCase): +    """A base class for Silence tests that correctly sets up the cog and redis.""" + +    @autospec(silence, "Scheduler", pass_mocks=False) +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id)) +        self.cog = silence.Silence(self.bot) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        await self.cog.cog_load()  # Populate instance attributes. + + +class SilenceNotifierTests(SilenceTest):      def setUp(self) -> None: +        super().setUp()          self.alert_channel = MockTextChannel()          self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock() @@ -89,34 +104,24 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceCogTests(RedisTestCase): +class SilenceCogTests(SilenceTest):      """Tests for the general functionality of the Silence cog.""" -    @autospec(silence, "Scheduler", pass_mocks=False) -    def setUp(self) -> None: -        self.bot = MockBot() -        self.cog = silence.Silence(self.bot) -      @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog.cog_load()          self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id)      @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_channels(self):          """Got channels from bot.""" -        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -          await self.cog.cog_load()          self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)      @autospec(silence, "SilenceNotifier")      async def test_cog_load_got_notifier(self, notifier):          """Notifier was started with channel.""" -        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) -          await self.cog.cog_load()          notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))          self.assertEqual(self.cog.notifier, notifier.return_value) @@ -229,13 +234,9 @@ class SilenceCogTests(RedisTestCase):              self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(RedisTestCase): +class SilenceArgumentParserTests(SilenceTest):      """Tests for the silence argument parser utility function.""" -    def setUp(self): -        self.bot = MockBot() -        self.cog = silence.Silence(self.bot) -      @autospec(silence.Silence, "send_message", pass_mocks=False)      @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False)      @autospec(silence.Silence, "parse_silence_args") @@ -303,17 +304,19 @@ class SilenceArgumentParserTests(RedisTestCase):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class RescheduleTests(unittest.IsolatedAsyncioTestCase): +class RescheduleTests(RedisTestCase):      """Tests for the rescheduling of cached unsilences.""" -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) -    def setUp(self): +    @autospec(silence, "Scheduler", pass_mocks=False) +    def setUp(self) -> None:          self.bot = MockBot()          self.cog = silence.Silence(self.bot)          self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) -        with mock.patch.object(self.cog, "_reschedule", autospec=True): -            asyncio.run(self.cog.cog_load())  # Populate instance attributes. +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        await self.cog.cog_load()  # Populate instance attributes.      async def test_skipped_missing_channel(self):          """Did nothing because the channel couldn't be retrieved.""" @@ -388,20 +391,14 @@ def voice_sync_helper(function):  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) -class SilenceTests(RedisTestCase): +class SilenceTests(SilenceTest):      """Tests for the silence command and its related helper methods.""" -    @autospec(silence.Silence, "_reschedule", pass_mocks=False) -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) -        self.cog = silence.Silence(self.bot) +        super().setUp()          # Avoid unawaited coroutine warnings.          self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() - -        asyncio.run(self.cog.cog_load())  # Populate instance attributes. -          self.text_channel = MockTextChannel()          self.text_overwrite = PermissionOverwrite(              send_messages=True, @@ -659,22 +656,13 @@ class SilenceTests(RedisTestCase):  @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) -class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +class UnsilenceTests(SilenceTest):      """Tests for the unsilence command and its related helper methods.""" -    @autospec(silence.Silence, "_reschedule", pass_mocks=False) -    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) -        self.cog = silence.Silence(self.bot) - -        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) -        self.cog.previous_overwrites = overwrites_cache - -        asyncio.run(self.cog.cog_load())  # Populate instance attributes. +        super().setUp()          self.cog.scheduler.__contains__.return_value = True -        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'          self.text_channel = MockTextChannel()          self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False)          self.text_channel.overwrites_for.return_value = self.text_overwrite @@ -683,6 +671,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):          self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)          self.voice_channel.overwrites_for.return_value = self.voice_overwrite +    async def asyncSetUp(self) -> None: +        await super().asyncSetUp() +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +      async def test_sent_correct_message(self):          """Appropriate failure/success message was sent by the command."""          unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index f8805ac48..e1f904917 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,15 +1,32 @@ -from typing import Iterable +from typing import Iterable, Optional + +import discord  from bot.rules import mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage +from tests.helpers import MockMember, MockMessage, MockMessageReference -def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: -    """Makes a message with `total_mentions` mentions.""" +def make_msg( +    author: str, +    total_user_mentions: int, +    total_bot_mentions: int = 0, +    *, +    reference: Optional[MockMessageReference] = None +) -> MockMessage: +    """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions."""      user_mentions = [MockMember() for _ in range(total_user_mentions)]      bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] -    return MockMessage(author=author, mentions=user_mentions+bot_mentions) + +    mentions = user_mentions + bot_mentions +    if reference is not None: +        # For the sake of these tests we assume that all references are mentions. +        mentions.append(reference.resolved.author) +        msg_type = discord.MessageType.reply +    else: +        msg_type = discord.MessageType.default + +    return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type)  class TestMentions(RuleTest): @@ -56,6 +73,16 @@ class TestMentions(RuleTest):                  ("bob",),                  3,              ), +            DisallowedCase( +                [make_msg("bob", 3, reference=MockMessageReference())], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], +                ("bob",), +                3 +            )          )          await self.run_disallowed(cases) @@ -71,6 +98,27 @@ class TestMentions(RuleTest):          await self.run_allowed(cases) +    async def test_ignore_reply_mentions(self): +        """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" +        cases = ( +            [ +                make_msg("bob", 2, reference=MockMessageReference()) +            ], +            [ +                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) +            ], +            [ +                make_msg("bob", 2, reference=MockMessageReference()), +                make_msg("bob", 0, reference=MockMessageReference()) +            ], +            [ +                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), +                make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) +            ] +        ) + +        await self.run_allowed(cases) +      def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:          last_message = case.recent_messages[0]          return tuple( diff --git a/tests/helpers.py b/tests/helpers.py index 17214553c..a4b919dcb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -317,7 +317,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):          guild_id=1,          intents=discord.Intents.all(),      ) -    additional_spec_asyncs = ("wait_for", "redis_ready") +    additional_spec_asyncs = ("wait_for",)      def __init__(self, **kwargs) -> None:          super().__init__(**kwargs) @@ -492,6 +492,28 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):      spec_set = attachment_instance +message_reference_instance = discord.MessageReference( +    message_id=unittest.mock.MagicMock(id=1), +    channel_id=unittest.mock.MagicMock(id=2), +    guild_id=unittest.mock.MagicMock(id=3) +) + + +class MockMessageReference(CustomMockMixin, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock MessageReference objects. + +    Instances of this class will follow the specification of `discord.MessageReference` instances. +    For more information, see the `MockGuild` docstring. +    """ +    spec_set = message_reference_instance + +    def __init__(self, *, reference_author_is_bot: bool = False, **kwargs): +        super().__init__(**kwargs) +        referenced_msg_author = MockMember(name="bob", bot=reference_author_is_bot) +        self.resolved = MockMessage(author=referenced_msg_author) + +  class MockMessage(CustomMockMixin, unittest.mock.MagicMock):      """      A MagicMock subclass to mock Message objects. | 
