diff options
| author | 2021-08-29 02:10:43 +0300 | |
|---|---|---|
| committer | 2021-08-29 02:10:43 +0300 | |
| commit | 13308200ff62784832ba9f9084b69cd3a214b966 (patch) | |
| tree | d60f38cabe85e8ae3839622f98f4d081a99e7738 | |
| parent | Send message when no messages found (diff) | |
`until` and `between` overhaul
- The two subcommands can now accept a time delta and an ISO date time in addition to messages.
- The two limits are now exclusive. Meaning cleaning until a message will not delete that message.
- Added a separate predicate for the `until` case, as the combination of that command and cache usage would result in incorrect behavior.
Additionally, deleting from cache now correctly traverses only `traverse` messages, rather than trying to delete `traverse` messages.
| -rw-r--r-- | bot/converters.py | 19 | ||||
| -rw-r--r-- | bot/exts/moderation/clean.py | 145 |
2 files changed, 111 insertions, 53 deletions
diff --git a/bot/converters.py b/bot/converters.py index 0118cc48a..546f6e8f4 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -388,6 +388,24 @@ class Duration(DurationDelta): raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") +class Age(DurationDelta): + """Convert duration strings into UTC datetime.datetime objects.""" + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the past. + + The converter supports the same symbols for each unit of time as its parent class. + """ + delta = await super().convert(ctx, duration) + now = datetime.utcnow() + + try: + return now - delta + except (ValueError, OverflowError): + raise BadArgument(f"`{duration}` results in a datetime outside the supported range.") + + class OffTopicName(Converter): """A converter that ensures an added off-topic name is valid.""" @@ -554,6 +572,7 @@ if t.TYPE_CHECKING: SourceConverter = SourceType # noqa: F811 DurationDelta = relativedelta # noqa: F811 Duration = datetime # noqa: F811 + Age = datetime # noqa: F811 OffTopicName = str # noqa: F811 ISODateTime = datetime # noqa: F811 HushDurationConverter = int # noqa: F811 diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index 1d323fa0b..90f7f3e03 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -2,17 +2,20 @@ import logging import re import time from collections import defaultdict +from datetime import datetime +from itertools import islice from typing import Any, Callable, DefaultDict, Iterable, List, Literal, Optional, TYPE_CHECKING, Tuple, Union from discord import Colour, Embed, Message, NotFound, TextChannel, User, errors from discord.ext.commands import Cog, Context, Converter, group, has_any_role from discord.ext.commands.converter import TextChannelConverter -from discord.ext.commands.errors import BadArgument, MaxConcurrencyReached, MissingRequiredArgument +from discord.ext.commands.errors import BadArgument, MaxConcurrencyReached from bot.bot import Bot from bot.constants import ( Channels, CleanMessages, Colours, Emojis, Event, Icons, MODERATION_ROLES ) +from bot.converters import Age, ISODateTime from bot.exts.moderation.modlog import ModLog from bot.utils.channel import is_mod_channel @@ -21,6 +24,8 @@ log = logging.getLogger(__name__) # Type alias for checks Predicate = Callable[[Message], bool] +CleanLimit = Union[Message, Age, ISODateTime] + class CleanChannels(Converter): """A converter that turns the given string to a list of channels to clean, or the literal `*` for all channels.""" @@ -66,46 +71,40 @@ class Clean(Cog): channels: CleanChannels, bots_only: bool, user: User, - until_message: Message, - after_message: Message, + first_limit: CleanLimit, + second_limit: CleanLimit, use_cache: bool ) -> None: """Raise errors if an argument value or a combination of values is invalid.""" # Is this an acceptable amount of messages to traverse? if traverse > CleanMessages.message_limit: - raise BadArgument(f"You cannot traverse more than {CleanMessages.message_limit} messages.") + raise BadArgument(f"Cannot traverse more than {CleanMessages.message_limit} messages.") - if after_message: - # Ensure that until_message is specified. - if not until_message: - raise MissingRequiredArgument("`until_message` must be specified if `after_message` is specified.") + if (isinstance(first_limit, Message) or isinstance(first_limit, Message)) and channels: + raise BadArgument("Both a message limit and channels specified.") - # Messages are not in same channel - if after_message.channel != until_message.channel: - raise BadArgument("You cannot do range clean across several channel.") + if isinstance(first_limit, Message) and isinstance(second_limit, Message): + # Messages are not in same channel. + if first_limit.channel != second_limit.channel: + raise BadArgument("Message limits are in different channels.") - # Ensure that after_message is younger than until_message - if after_message.created_at >= until_message.created_at: - raise BadArgument("`after` message must be younger than `until` message") + # This is an implementation error rather than user error. + if second_limit and not first_limit: + raise ValueError("Second limit specified without the first.") def _get_messages_from_cache(self, traverse: int, to_delete: Predicate) -> Tuple[DefaultDict, List[int]]: """Helper function for getting messages from the cache.""" message_mappings = defaultdict(list) message_ids = [] - for message in self.bot.cached_messages: + for message in islice(self.bot.cached_messages, traverse): if not self.cleaning: # Cleaning was canceled - return (message_mappings, message_ids) + return message_mappings, message_ids if to_delete(message): message_mappings[message.channel].append(message) message_ids.append(message.id) - if len(message_ids) == traverse: - # We traversed the requested amount of messages. - return message_mappings, message_ids - - # There are fewer messages in the cache than the number requested to traverse. return message_mappings, message_ids async def _get_messages_from_channels( @@ -113,27 +112,19 @@ class Clean(Cog): traverse: int, channels: Iterable[TextChannel], to_delete: Predicate, - until_message: Optional[Message] = None + before: Optional[datetime] = None, + after: Optional[datetime] = None ) -> tuple[defaultdict[Any, list], list]: message_mappings = defaultdict(list) message_ids = [] for channel in channels: - - async for message in channel.history(limit=traverse): + async for message in channel.history(limit=traverse, before=before, after=after): if not self.cleaning: - # Cleaning was canceled, return empty containers + # Cleaning was canceled, return empty containers. return defaultdict(list), [] - if until_message: - - # We could use ID's here however in case if the message we are looking for gets deleted, - # We won't have a way to figure that out thus checking for datetime should be more reliable - if message.created_at < until_message.created_at: - # Means we have found the message until which we were supposed to be deleting. - break - if to_delete(message): message_mappings[message.channel].append(message) message_ids.append(message.id) @@ -253,8 +244,8 @@ class Clean(Cog): bots_only: bool = False, user: User = None, regex: Optional[str] = None, - until_message: Optional[Message] = None, - after_message: Optional[Message] = None, + first_limit: Optional[CleanLimit] = None, + second_limit: Optional[CleanLimit] = None, use_cache: Optional[bool] = True ) -> None: """A helper function that does the actual message cleaning.""" @@ -291,10 +282,14 @@ class Clean(Cog): return bool(re.search(regex.lower(), content.lower())) def predicate_range(message: Message) -> bool: - """Check if message is older than message provided in after_message but younger than until_message.""" - return after_message.created_at <= message.created_at <= until_message.created_at + """Check if the message age is between the two limits.""" + return first_limit <= message.created_at <= second_limit - self._validate_input(traverse, channels, bots_only, user, until_message, after_message, use_cache) + def predicate_after(message: Message) -> bool: + """Check if the message is older than the first limit.""" + return message.created_at >= first_limit + + self._validate_input(traverse, channels, bots_only, user, first_limit, second_limit, use_cache) # Are we already performing a clean? if self.cleaning: @@ -308,17 +303,31 @@ class Clean(Cog): predicate = predicate_specific_user # Delete messages from specific user elif regex: predicate = predicate_regex # Delete messages that match regex - elif after_message: - predicate = predicate_range # Delete messages older than specific message + elif second_limit: + predicate = predicate_range # Delete messages in the specified age range + elif first_limit: + predicate = predicate_after # Delete messages older than specific message else: predicate = lambda m: True # Delete all messages # noqa: E731 - # Default to using the invoking context's channel + # Default to using the invoking context's channel or the channel of the message limit(s). if not channels: - channels = [ctx.channel] + # At this point second_limit is guaranteed to not exist, be a datetime, or a message in the same channel. + if isinstance(first_limit, Message): + channels = [first_limit.channel] + elif isinstance(second_limit, Message): + channels = [second_limit.channel] + else: + channels = [ctx.channel] - if not is_mod_channel(ctx.channel): + if isinstance(first_limit, Message): + first_limit = first_limit.created_at + if isinstance(second_limit, Message): + second_limit = second_limit.created_at + if first_limit and second_limit: + first_limit, second_limit = sorted([first_limit, second_limit]) + if not is_mod_channel(ctx.channel): # Delete the invocation first self.mod_log.ignore(Event.message_delete, ctx.message.id) try: @@ -337,7 +346,8 @@ class Clean(Cog): traverse=traverse, channels=deletion_channels, to_delete=predicate, - until_message=until_message + before=second_limit, + after=first_limit # Remember first is the earlier datetime. ) if not self.cleaning: @@ -418,25 +428,54 @@ class Clean(Cog): @clean_group.command(name="until") @has_any_role(*MODERATION_ROLES) - async def clean_until(self, ctx: Context, message: Message) -> None: - """Delete all messages until certain message, stop cleaning after hitting the `message`.""" + async def clean_until( + self, + ctx: Context, + until: CleanLimit, + use_cache: Optional[bool] = True, + *, + channels: Optional[CleanChannels] = None) -> None: + """ + Delete all messages until a certain limit. + + A limit can be either a message, and ISO date-time string, or a time delta. + If a message is specified, `channels` cannot be specified. + """ await self._clean_messages( CleanMessages.message_limit, ctx, - channels=[message.channel], - until_message=message + channels=channels, + first_limit=until, + use_cache=use_cache ) @clean_group.command(name="between", aliases=["after-until", "from-to"]) @has_any_role(*MODERATION_ROLES) - async def clean_between(self, ctx: Context, after_message: Message, until_message: Message) -> None: - """Delete all messages within range of messages.""" + async def clean_between( + self, + ctx: Context, + first_limit: CleanLimit, + second_limit: CleanLimit, + use_cache: Optional[bool] = True, + *, + channels: Optional[CleanChannels] = None + ) -> None: + """ + Delete all messages within range. + + The range is specified through two limits. + A limit can be either a message, and ISO date-time string, or a time delta. + + If two messages are specified, they both must be in the same channel. + If a message is specified, `channels` cannot be specified. + """ await self._clean_messages( CleanMessages.message_limit, ctx, - channels=[until_message.channel], - until_message=until_message, - after_message=after_message, + channels=channels, + first_limit=first_limit, + second_limit=second_limit, + use_cache=use_cache ) @clean_group.command(name="stop", aliases=["cancel", "abort"]) |