diff options
| -rw-r--r-- | bot/converters.py | 1 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 5 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/_utils.py | 32 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 42 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/management.py | 24 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 4 | ||||
| -rw-r--r-- | bot/utils/time.py | 44 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 11 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 29 | 
9 files changed, 133 insertions, 59 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 5800ea044..e97a25bdd 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -574,5 +574,6 @@ if t.TYPE_CHECKING:      Infraction = t.Optional[dict]  # noqa: F811  Expiry = t.Union[Duration, ISODateTime] +DurationOrExpiry = t.Union[DurationDelta, ISODateTime]  MemberOrUser = t.Union[discord.Member, discord.User]  UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser] diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 280b0fb0c..655290559 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -137,8 +137,11 @@ class InfractionScheduler:          infr_type = infraction["type"]          icon = _utils.INFRACTION_ICONS[infr_type][0]          reason = infraction["reason"] -        expiry = time.format_with_duration(infraction["expires_at"])          id_ = infraction['id'] +        expiry = time.format_with_duration( +            infraction["expires_at"], +            infraction["last_applied"] +        )          if user_reason is None:              user_reason = reason diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 3a2485ec2..c03081b07 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -1,5 +1,4 @@  import typing as t -from datetime import datetime  import arrow  import discord @@ -8,10 +7,11 @@ from discord.ext.commands import Context  import bot  from bot.constants import Colours, Icons -from bot.converters import MemberOrUser +from bot.converters import DurationOrExpiry, MemberOrUser  from bot.errors import InvalidInfractedUserError  from bot.log import get_logger  from bot.utils import time +from bot.utils.time import unpack_duration  log = get_logger(__name__) @@ -44,8 +44,8 @@ LONGEST_EXTRAS = max(len(INFRACTION_APPEAL_SERVER_FOOTER), len(INFRACTION_APPEAL  INFRACTION_DESCRIPTION_TEMPLATE = (      "**Type:** {type}\n" -    "**Expires:** {expires}\n"      "**Duration:** {duration}\n" +    "**Expires:** {expires}\n"      "**Reason:** {reason}\n"  ) @@ -80,7 +80,7 @@ async def post_infraction(          user: MemberOrUser,          infr_type: str,          reason: str, -        expires_at: datetime = None, +        duration_or_expiry: t.Optional[DurationOrExpiry] = None,          hidden: bool = False,          active: bool = True,          dm_sent: bool = False, @@ -92,6 +92,8 @@ async def post_infraction(      log.trace(f"Posting {infr_type} infraction for {user} to the API.") +    current_time = arrow.utcnow() +      payload = {          "actor": ctx.author.id,  # Don't use ctx.message.author; antispam only patches ctx.author.          "hidden": hidden, @@ -99,10 +101,14 @@ async def post_infraction(          "type": infr_type,          "user": user.id,          "active": active, -        "dm_sent": dm_sent +        "dm_sent": dm_sent, +        "inserted_at": current_time.isoformat(), +        "last_applied": current_time.isoformat(),      } -    if expires_at: -        payload['expires_at'] = expires_at.isoformat() + +    if duration_or_expiry is not None: +        _, expiry = unpack_duration(duration_or_expiry, current_time) +        payload["expires_at"] = expiry.isoformat()      # Try to apply the infraction. If it fails because the user doesn't exist, try to add it.      for should_post_user in (True, False): @@ -180,17 +186,17 @@ async def notify_infraction(          expires_at = "Never"          duration = "Permanent"      else: +        origin = arrow.get(infraction["last_applied"])          expiry = arrow.get(infraction["expires_at"])          expires_at = time.format_relative(expiry) -        duration = time.humanize_delta(infraction["inserted_at"], expiry, max_units=2) +        duration = time.humanize_delta(origin, expiry, max_units=2) -        if infraction["active"]: -            remaining = time.humanize_delta(expiry, arrow.utcnow(), max_units=2) -            if duration != remaining: -                duration += f" ({remaining} remaining)" -        else: +        if not infraction["active"]:              expires_at += " (Inactive)" +        if infraction["inserted_at"] != infraction["last_applied"]: +            duration += " (Edited)" +      log.trace(f"Sending {user} a DM about their {infr_type} infraction.")      if reason is None: diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 08a3609a7..05cc74a03 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -10,7 +10,7 @@ from discord.ext.commands import Context, command  from bot import constants  from bot.bot import Bot  from bot.constants import Event -from bot.converters import Age, Duration, Expiry, MemberOrUser, UnambiguousMemberOrUser +from bot.converters import Age, Duration, DurationOrExpiry, MemberOrUser, UnambiguousMemberOrUser  from bot.decorators import ensure_future_timestamp, respect_role_hierarchy  from bot.exts.filters.filtering import AUTO_BAN_DURATION, AUTO_BAN_REASON  from bot.exts.moderation.infraction import _utils @@ -88,16 +88,18 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration_or_expiry: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str] = None      ) -> None:          """ -        Permanently ban a user for the given reason and stop watching them with Big Brother. +        Permanently ban a `user` for the given `reason` and stop watching them with Big Brother. -        If duration is specified, it temporarily bans that user for the given duration. +        If a duration is specified, it temporarily bans the `user` for the given duration. +        Alternatively, an ISO 8601 timestamp representing the expiry time can be provided +        for `duration_or_expiry`.          """ -        await self.apply_ban(ctx, user, reason, expires_at=duration) +        await self.apply_ban(ctx, user, reason, duration_or_expiry=duration_or_expiry)      @command(aliases=("cban", "purgeban", "pban"))      @ensure_future_timestamp(timestamp_arg=3) @@ -105,7 +107,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str] = None      ) -> None: @@ -117,10 +119,10 @@ class Infractions(InfractionScheduler, commands.Cog):          clean_cog: t.Optional[Clean] = self.bot.get_cog("Clean")          if clean_cog is None:              # If we can't get the clean cog, fall back to native purgeban. -            await self.apply_ban(ctx, user, reason, purge_days=1, expires_at=duration) +            await self.apply_ban(ctx, user, reason, purge_days=1, duration_or_expiry=duration)              return -        infraction = await self.apply_ban(ctx, user, reason, expires_at=duration) +        infraction = await self.apply_ban(ctx, user, reason, duration_or_expiry=duration)          if not infraction or not infraction.get("id"):              # Ban was unsuccessful, quit early.              await ctx.send(":x: Failed to apply ban.") @@ -175,7 +177,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str]      ) -> None: @@ -184,7 +186,7 @@ class Infractions(InfractionScheduler, commands.Cog):          If duration is specified, it temporarily voice mutes that user for the given duration.          """ -        await self.apply_voice_mute(ctx, user, reason, expires_at=duration) +        await self.apply_voice_mute(ctx, user, reason, duration_or_expiry=duration)      # endregion      # region: Temporary infractions @@ -194,7 +196,7 @@ class Infractions(InfractionScheduler, commands.Cog):      async def tempmute(          self, ctx: Context,          user: UnambiguousMemberOrUser, -        duration: t.Optional[Expiry] = None, +        duration: t.Optional[DurationOrExpiry] = None,          *,          reason: t.Optional[str] = None      ) -> None: @@ -221,7 +223,7 @@ class Infractions(InfractionScheduler, commands.Cog):          if duration is None:              duration = await Duration().convert(ctx, "1h") -        await self.apply_mute(ctx, user, reason, expires_at=duration) +        await self.apply_mute(ctx, user, reason, duration_or_expiry=duration)      @command(aliases=("tban",))      @ensure_future_timestamp(timestamp_arg=3) @@ -229,7 +231,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: Expiry, +        duration_or_expiry: DurationOrExpiry,          *,          reason: t.Optional[str] = None      ) -> None: @@ -248,7 +250,7 @@ class Infractions(InfractionScheduler, commands.Cog):          Alternatively, an ISO 8601 timestamp can be provided for the duration.          """ -        await self.apply_ban(ctx, user, reason, expires_at=duration) +        await self.apply_ban(ctx, user, reason, duration_or_expiry=duration_or_expiry)      @command(aliases=("tempvban", "tvban"))      async def tempvoiceban(self, ctx: Context) -> None: @@ -265,7 +267,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: Expiry, +        duration: DurationOrExpiry,          *,          reason: t.Optional[str]      ) -> None: @@ -284,7 +286,7 @@ class Infractions(InfractionScheduler, commands.Cog):          Alternatively, an ISO 8601 timestamp can be provided for the duration.          """ -        await self.apply_voice_mute(ctx, user, reason, expires_at=duration) +        await self.apply_voice_mute(ctx, user, reason, duration_or_expiry=duration)      # endregion      # region: Permanent shadow infractions @@ -312,7 +314,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self,          ctx: Context,          user: UnambiguousMemberOrUser, -        duration: Expiry, +        duration: DurationOrExpiry,          *,          reason: t.Optional[str] = None      ) -> None: @@ -331,7 +333,7 @@ class Infractions(InfractionScheduler, commands.Cog):          Alternatively, an ISO 8601 timestamp can be provided for the duration.          """ -        await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True) +        await self.apply_ban(ctx, user, reason, duration_or_expiry=duration, hidden=True)      # endregion      # region: Remove infractions (un- commands) @@ -435,7 +437,7 @@ class Infractions(InfractionScheduler, commands.Cog):              return None          # In the case of a permanent ban, we don't need get_active_infractions to tell us if one is active -        is_temporary = kwargs.get("expires_at") is not None +        is_temporary = kwargs.get("duration_or_expiry") is not None          active_infraction = await _utils.get_active_infraction(ctx, user, "ban", is_temporary)          if active_infraction: @@ -443,7 +445,7 @@ class Infractions(InfractionScheduler, commands.Cog):                  log.trace("Tempban ignored as it cannot overwrite an active ban.")                  return None -            if active_infraction.get('expires_at') is None: +            if active_infraction.get("duration_or_expiry") is None:                  log.trace("Permaban already exists, notify.")                  await ctx.send(f":x: User is already permanently banned (#{active_infraction['id']}).")                  return None diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index a7d7a844a..6ef382119 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -2,6 +2,7 @@ import re  import textwrap  import typing as t +import arrow  import discord  from discord.ext import commands  from discord.ext.commands import Context @@ -9,7 +10,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser +from bot.converters import DurationOrExpiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser  from bot.decorators import ensure_future_timestamp  from bot.errors import InvalidInfraction  from bot.exts.moderation.infraction import _utils @@ -20,6 +21,7 @@ from bot.pagination import LinePaginator  from bot.utils import messages, time  from bot.utils.channel import is_mod_channel  from bot.utils.members import get_or_fetch_member +from bot.utils.time import unpack_duration  log = get_logger(__name__) @@ -89,7 +91,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction: Infraction, -        duration: t.Union[Expiry, t.Literal["p", "permanent"], None], +        duration: t.Union[DurationOrExpiry, t.Literal["p", "permanent"], None],          *,          reason: str = None      ) -> None: @@ -129,7 +131,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction: Infraction, -        duration: t.Union[Expiry, t.Literal["p", "permanent"], None], +        duration: t.Union[DurationOrExpiry, t.Literal["p", "permanent"], None],          *,          reason: str = None      ) -> None: @@ -172,8 +174,11 @@ class ModManagement(commands.Cog):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: -            request_data['expires_at'] = duration.isoformat() -            expiry = time.format_with_duration(duration) +            origin, expiry = unpack_duration(duration) +            # Update `last_applied` if expiry changes. +            request_data['last_applied'] = origin.isoformat() +            request_data['expires_at'] = expiry.isoformat() +            expiry = time.format_with_duration(expiry, origin)              confirm_messages.append(f"set to expire on {expiry}")          else:              confirm_messages.append("expiry unchanged") @@ -380,7 +385,10 @@ class ModManagement(commands.Cog):          user = infraction["user"]          expires_at = infraction["expires_at"]          inserted_at = infraction["inserted_at"] +        last_applied = infraction["last_applied"]          created = time.discord_timestamp(inserted_at) +        applied = time.discord_timestamp(last_applied) +        duration_edited = arrow.get(last_applied) > arrow.get(inserted_at)          dm_sent = infraction["dm_sent"]          # Format the user string. @@ -400,7 +408,11 @@ class ModManagement(commands.Cog):          if expires_at is None:              duration = "*Permanent*"          else: -            duration = time.humanize_delta(inserted_at, expires_at) +            duration = time.humanize_delta(last_applied, expires_at) + +        # Notice if infraction expiry was edited. +        if duration_edited: +            duration += f" (edited {applied})"          # Format `dm_sent`          if dm_sent is None: diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 0e6aaa1e7..f2aab7a92 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -10,7 +10,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Duration, Expiry +from bot.converters import Duration, DurationOrExpiry  from bot.decorators import ensure_future_timestamp  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler @@ -109,7 +109,7 @@ class Superstarify(InfractionScheduler, Cog):          self,          ctx: Context,          member: Member, -        duration: t.Optional[Expiry], +        duration: t.Optional[DurationOrExpiry],          *,          reason: str = '',      ) -> None: diff --git a/bot/utils/time.py b/bot/utils/time.py index a0379c3ef..820ac2929 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,12 +1,18 @@ +from __future__ import annotations +  import datetime  import re +from copy import copy  from enum import Enum  from time import struct_time -from typing import Literal, Optional, Union, overload +from typing import Literal, Optional, TYPE_CHECKING, Union, overload  import arrow  from dateutil.relativedelta import relativedelta +if TYPE_CHECKING: +    from bot.converters import DurationOrExpiry +  _DURATION_REGEX = re.compile(      r"((?P<years>\d+?) ?(years|year|Y|y) ?)?"      r"((?P<months>\d+?) ?(months|month|m) ?)?" @@ -194,8 +200,8 @@ def humanize_delta(      elif len(args) <= 2:          end = arrow.get(args[0])          start = arrow.get(args[1]) if len(args) == 2 else arrow.utcnow() +        delta = round_delta(relativedelta(end.datetime, start.datetime)) -        delta = relativedelta(end.datetime, start.datetime)          if absolute:              delta = abs(delta)      else: @@ -326,3 +332,37 @@ def until_expiration(expiry: Optional[Timestamp]) -> str:          return "Expired"      return format_relative(expiry) + + +def unpack_duration( +        duration_or_expiry: DurationOrExpiry, +        origin: Optional[Union[datetime.datetime, arrow.Arrow]] = None +) -> tuple[datetime.datetime, datetime.datetime]: +    """ +    Unpacks a DurationOrExpiry into a tuple of (origin, expiry). + +    The `origin` defaults to the current UTC time at function call. +    """ +    if origin is None: +        origin = datetime.datetime.now(tz=datetime.timezone.utc) + +    if isinstance(origin, arrow.Arrow): +        origin = origin.datetime + +    if isinstance(duration_or_expiry, relativedelta): +        return origin, origin + duration_or_expiry +    else: +        return origin, duration_or_expiry + + +def round_delta(delta: relativedelta) -> relativedelta: +    """ +    Rounds `delta` to the nearest second. + +    Returns a copy with microsecond values of 0. +    """ +    delta = copy(delta) +    if delta.microseconds >= 500000: +        delta += relativedelta(seconds=1) +    delta.microseconds = 0 +    return delta diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index 052048053..a18a4d23b 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -79,13 +79,13 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          """Should call voice mute applying function without expiry."""          self.cog.apply_voice_mute = AsyncMock()          self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar")) -        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None) +        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry=None)      async def test_temporary_voice_mute(self):          """Should call voice mute applying function with expiry."""          self.cog.apply_voice_mute = AsyncMock()          self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar")) -        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") +        self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", duration_or_expiry="baz")      async def test_voice_unmute(self):          """Should call infraction pardoning function.""" @@ -189,7 +189,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          user = MockUser()          await self.cog.voicemute(self.cog, self.ctx, user, reason=None) -        post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None) +        post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, +                                                     duration_or_expiry=None)          apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY)          # Test action @@ -273,7 +274,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):              self.user,              "FooBar",              purge_days=1, -            expires_at=None, +            duration_or_expiry=None,          )      async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self): @@ -285,7 +286,7 @@ class CleanBanTests(unittest.IsolatedAsyncioTestCase):              self.ctx,              self.user,              "FooBar", -            expires_at=None, +            duration_or_expiry=None,          )      @patch("bot.exts.moderation.infraction.infractions.Age") diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 5cf02033d..29dadf372 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -1,7 +1,7 @@  import unittest  from collections import namedtuple  from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, patch  from botcore.site_api import ResponseCodeError  from discord import Embed, Forbidden, HTTPException, NotFound @@ -309,8 +309,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):      async def test_normal_post_infraction(self):          """Should return response from POST request if there are no errors.""" -        now = datetime.now() -        payload = { +        now = datetime.utcnow() +        expected = {              "actor": self.ctx.author.id,              "hidden": True,              "reason": "Test reason", @@ -318,14 +318,17 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):              "user": self.member.id,              "active": False,              "expires_at": now.isoformat(), -            "dm_sent": False +            "dm_sent": False,          }          self.ctx.bot.api_client.post.return_value = "foo"          actual = await utils.post_infraction(self.ctx, self.member, "ban", "Test reason", now, True, False) -          self.assertEqual(actual, "foo") -        self.ctx.bot.api_client.post.assert_awaited_once_with("bot/infractions", json=payload) +        self.ctx.bot.api_client.post.assert_awaited_once() + +        # Since `last_applied` is based on current time, just check if expected is a subset of payload +        payload: dict = self.ctx.bot.api_client.post.await_args_list[0].kwargs["json"] +        self.assertEqual(payload, payload | expected)      async def test_unknown_error_post_infraction(self):          """Should send an error message to chat when a non-400 error occurs.""" @@ -349,19 +352,25 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):      @patch("bot.exts.moderation.infraction._utils.post_user", return_value="bar")      async def test_first_fail_second_success_user_post_infraction(self, post_user_mock):          """Should post the user if they don't exist, POST infraction again, and return the response if successful.""" -        payload = { +        expected = {              "actor": self.ctx.author.id,              "hidden": False,              "reason": "Test reason",              "type": "mute",              "user": self.user.id,              "active": True, -            "dm_sent": False +            "dm_sent": False,          }          self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"] -          actual = await utils.post_infraction(self.ctx, self.user, "mute", "Test reason")          self.assertEqual(actual, "foo") -        self.bot.api_client.post.assert_has_awaits([call("bot/infractions", json=payload)] * 2) +        await_args = self.bot.api_client.post.await_args_list +        self.assertEqual(len(await_args), 2, "Expected 2 awaits") + +        # Since `last_applied` is based on current time, just check if expected is a subset of payload +        for args in await_args: +            payload: dict = args.kwargs["json"] +            self.assertEqual(payload, payload | expected) +          post_user_mock.assert_awaited_once_with(self.ctx, self.user) | 
