diff options
| -rw-r--r-- | bot/exts/moderation/infraction/_utils.py | 9 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/management.py | 24 | ||||
| -rw-r--r-- | bot/utils/time.py | 49 | 
3 files changed, 59 insertions, 23 deletions
| diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 407c97251..5e708d7fe 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -12,6 +12,7 @@ from bot.converters import DurationOrExpiry, MemberOrUser  from bot.errors import InvalidInfractedUserError  from bot.log import get_logger  from bot.utils import time +from bot.utils.time import unpack_duration  log = get_logger(__name__) @@ -102,15 +103,13 @@ async def post_infraction(          "user": user.id,          "active": active,          "dm_sent": dm_sent, +        "inserted_at": current_time.isoformat(),          "last_applied": current_time.isoformat(),      } -    # Parse duration or expiry      if duration_or_expiry is not None: -        if isinstance(duration_or_expiry, datetime): -            payload['expires_at'] = duration_or_expiry.isoformat() -        else:  # is relativedelta -            payload['expires_at'] = (current_time + duration_or_expiry).isoformat() +        _, expiry = unpack_duration(duration_or_expiry, current_time) +        payload["expires_at"] = expiry.isoformat()      # Try to apply the infraction. If it fails because the user doesn't exist, try to add it.      for should_post_user in (True, False): diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index a7d7a844a..6ef382119 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -2,6 +2,7 @@ import re  import textwrap  import typing as t +import arrow  import discord  from discord.ext import commands  from discord.ext.commands import Context @@ -9,7 +10,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser +from bot.converters import DurationOrExpiry, Infraction, MemberOrUser, Snowflake, UnambiguousUser  from bot.decorators import ensure_future_timestamp  from bot.errors import InvalidInfraction  from bot.exts.moderation.infraction import _utils @@ -20,6 +21,7 @@ from bot.pagination import LinePaginator  from bot.utils import messages, time  from bot.utils.channel import is_mod_channel  from bot.utils.members import get_or_fetch_member +from bot.utils.time import unpack_duration  log = get_logger(__name__) @@ -89,7 +91,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction: Infraction, -        duration: t.Union[Expiry, t.Literal["p", "permanent"], None], +        duration: t.Union[DurationOrExpiry, t.Literal["p", "permanent"], None],          *,          reason: str = None      ) -> None: @@ -129,7 +131,7 @@ class ModManagement(commands.Cog):          self,          ctx: Context,          infraction: Infraction, -        duration: t.Union[Expiry, t.Literal["p", "permanent"], None], +        duration: t.Union[DurationOrExpiry, t.Literal["p", "permanent"], None],          *,          reason: str = None      ) -> None: @@ -172,8 +174,11 @@ class ModManagement(commands.Cog):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: -            request_data['expires_at'] = duration.isoformat() -            expiry = time.format_with_duration(duration) +            origin, expiry = unpack_duration(duration) +            # Update `last_applied` if expiry changes. +            request_data['last_applied'] = origin.isoformat() +            request_data['expires_at'] = expiry.isoformat() +            expiry = time.format_with_duration(expiry, origin)              confirm_messages.append(f"set to expire on {expiry}")          else:              confirm_messages.append("expiry unchanged") @@ -380,7 +385,10 @@ class ModManagement(commands.Cog):          user = infraction["user"]          expires_at = infraction["expires_at"]          inserted_at = infraction["inserted_at"] +        last_applied = infraction["last_applied"]          created = time.discord_timestamp(inserted_at) +        applied = time.discord_timestamp(last_applied) +        duration_edited = arrow.get(last_applied) > arrow.get(inserted_at)          dm_sent = infraction["dm_sent"]          # Format the user string. @@ -400,7 +408,11 @@ class ModManagement(commands.Cog):          if expires_at is None:              duration = "*Permanent*"          else: -            duration = time.humanize_delta(inserted_at, expires_at) +            duration = time.humanize_delta(last_applied, expires_at) + +        # Notice if infraction expiry was edited. +        if duration_edited: +            duration += f" (edited {applied})"          # Format `dm_sent`          if dm_sent is None: diff --git a/bot/utils/time.py b/bot/utils/time.py index 104ea026d..820ac2929 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,12 +1,18 @@ +from __future__ import annotations +  import datetime  import re +from copy import copy  from enum import Enum  from time import struct_time -from typing import Literal, Optional, Union, overload +from typing import Literal, Optional, TYPE_CHECKING, Union, overload  import arrow  from dateutil.relativedelta import relativedelta +if TYPE_CHECKING: +    from bot.converters import DurationOrExpiry +  _DURATION_REGEX = re.compile(      r"((?P<years>\d+?) ?(years|year|Y|y) ?)?"      r"((?P<months>\d+?) ?(months|month|m) ?)?" @@ -194,12 +200,8 @@ def humanize_delta(      elif len(args) <= 2:          end = arrow.get(args[0])          start = arrow.get(args[1]) if len(args) == 2 else arrow.utcnow() +        delta = round_delta(relativedelta(end.datetime, start.datetime)) -        # Round microseconds -        end = round_datetime(end.datetime) -        start = round_datetime(start.datetime) - -        delta = relativedelta(end, start)          if absolute:              delta = abs(delta)      else: @@ -332,12 +334,35 @@ def until_expiration(expiry: Optional[Timestamp]) -> str:      return format_relative(expiry) -def round_datetime(dt: datetime.datetime) -> datetime.datetime: +def unpack_duration( +        duration_or_expiry: DurationOrExpiry, +        origin: Optional[Union[datetime.datetime, arrow.Arrow]] = None +) -> tuple[datetime.datetime, datetime.datetime]: +    """ +    Unpacks a DurationOrExpiry into a tuple of (origin, expiry). + +    The `origin` defaults to the current UTC time at function call. +    """ +    if origin is None: +        origin = datetime.datetime.now(tz=datetime.timezone.utc) + +    if isinstance(origin, arrow.Arrow): +        origin = origin.datetime + +    if isinstance(duration_or_expiry, relativedelta): +        return origin, origin + duration_or_expiry +    else: +        return origin, duration_or_expiry + + +def round_delta(delta: relativedelta) -> relativedelta:      """ -    Round a datetime object to the nearest second. +    Rounds `delta` to the nearest second. -    Resulting datetime objects will have microsecond values of 0, useful for delta comparisons. +    Returns a copy with microsecond values of 0.      """ -    if dt.microsecond >= 500000: -        dt += datetime.timedelta(seconds=1) -    return dt.replace(microsecond=0) +    delta = copy(delta) +    if delta.microseconds >= 500000: +        delta += relativedelta(seconds=1) +    delta.microseconds = 0 +    return delta | 
