diff options
-rw-r--r-- | bot/exts/moderation/clean.py | 207 |
1 files changed, 102 insertions, 105 deletions
diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index 826265aa3..e61ef7880 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -1,12 +1,11 @@ import contextlib -import logging import re import time from collections import defaultdict from contextlib import suppress from datetime import datetime -from itertools import islice -from typing import Any, Callable, Iterable, Literal, Optional, TYPE_CHECKING, Union +from itertools import takewhile +from typing import Callable, Iterable, Literal, Optional, TYPE_CHECKING, Union from discord import Colour, Message, NotFound, TextChannel, User, errors from discord.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role @@ -17,12 +16,11 @@ from bot.bot import Bot from bot.constants import Channels, CleanMessages, Colours, Emojis, Event, Icons, MODERATION_ROLES from bot.converters import Age, ISODateTime from bot.exts.moderation.modlog import ModLog +from bot.log import get_logger from bot.utils.channel import is_mod_channel -log = logging.getLogger(__name__) +log = get_logger(__name__) -# Default number of messages to look at in each channel. -DEFAULT_TRAVERSE = 10 # Number of seconds before command invocations and responses are deleted in non-moderation channels. MESSAGE_DELETE_DELAY = 5 @@ -33,12 +31,12 @@ CleanLimit = Union[Message, Age, ISODateTime] class CleanChannels(Converter): - """A converter that turns the given string to a list of channels to clean, or the literal `*` for all channels.""" + """A converter to turn the string into a list of channels to clean, or the literal `*` for all public channels.""" _channel_converter = TextChannelConverter() async def convert(self, ctx: Context, argument: str) -> Union[Literal["*"], list[TextChannel]]: - """Converts a string to a list of channels to clean, or the literal `*` for all channels.""" + """Converts a string to a list of channels to clean, or the literal `*` for all public channels.""" if argument == "*": return "*" return [await self._channel_converter.convert(ctx, channel) for channel in argument.split()] @@ -87,7 +85,6 @@ class Clean(Cog): @staticmethod def _validate_input( - traverse: int, channels: Optional[CleanChannels], bots_only: bool, users: Optional[list[User]], @@ -95,9 +92,9 @@ class Clean(Cog): second_limit: Optional[CleanLimit], ) -> None: """Raise errors if an argument value or a combination of values is invalid.""" - # Is this an acceptable amount of messages to traverse? - if traverse > CleanMessages.message_limit: - raise BadArgument(f"Cannot traverse more than {CleanMessages.message_limit} messages.") + if first_limit is None: + # This is an optional argument for the sake of the master command, but it's actually required. + raise BadArgument("Missing cleaning limit.") if (isinstance(first_limit, Message) or isinstance(second_limit, Message)) and channels: raise BadArgument("Both a message limit and channels specified.") @@ -110,10 +107,6 @@ class Clean(Cog): if users and bots_only: raise BadArgument("Marked as bots only, but users were specified.") - # This is an implementation error rather than user error. - if second_limit and not first_limit: - raise ValueError("Second limit specified without the first.") - @staticmethod async def _send_expiring_message(ctx: Context, content: str) -> None: """Send `content` to the context channel. Automatically delete if it's not a mod channel.""" @@ -121,12 +114,39 @@ class Clean(Cog): await ctx.send(content, delete_after=delete_after) @staticmethod + def _channels_set( + channels: CleanChannels, ctx: Context, first_limit: CleanLimit, second_limit: CleanLimit + ) -> set[TextChannel]: + """Standardize the input `channels` argument to a usable set of text channels.""" + # Default to using the invoking context's channel or the channel of the message limit(s). + if not channels: + # Input was validated - if first_limit is a message, second_limit won't point at a different channel. + if isinstance(first_limit, Message): + channels = {first_limit.channel} + elif isinstance(second_limit, Message): + channels = {second_limit.channel} + else: + channels = {ctx.channel} + else: + if channels == "*": + channels = { + channel for channel in ctx.guild.channels + if isinstance(channel, TextChannel) + # Assume that non-public channels are not needed to optimize for speed. + and channel.permissions_for(ctx.guild.default_role).view_channel + } + else: + channels = set(channels) + + return channels + + @staticmethod def _build_predicate( + first_limit: datetime, + second_limit: Optional[datetime] = None, bots_only: bool = False, users: Optional[list[User]] = None, regex: Optional[re.Pattern] = None, - first_limit: Optional[datetime] = None, - second_limit: Optional[datetime] = None, ) -> Predicate: """Return the predicate that decides whether to delete a given message.""" def predicate_bots_only(message: Message) -> bool: @@ -167,20 +187,18 @@ class Clean(Cog): predicates = [] # Set up the correct predicate + if second_limit: + predicates.append(predicate_range) # Delete messages in the specified age range + else: + predicates.append(predicate_after) # Delete messages older than the specified age + if bots_only: predicates.append(predicate_bots_only) # Delete messages from bots if users: predicates.append(predicate_specific_users) # Delete messages from specific user if regex: predicates.append(predicate_regex) # Delete messages that match regex - # Add up to one of the following: - if second_limit: - predicates.append(predicate_range) # Delete messages in the specified age range - elif first_limit: - predicates.append(predicate_after) # Delete messages older than specific message - if not predicates: - return lambda m: True if len(predicates) == 1: return predicates[0] return lambda m: all(pred(m) for pred in predicates) @@ -195,16 +213,25 @@ class Clean(Cog): # Invocation message has already been deleted log.info("Tried to delete invocation message, but it was already deleted.") - def _get_messages_from_cache(self, traverse: int, to_delete: Predicate) -> tuple[defaultdict[Any, list], list[int]]: + def _use_cache(self, limit: datetime) -> bool: + """Tell whether all messages to be cleaned can be found in the cache.""" + return self.bot.cached_messages[0].created_at <= limit + + def _get_messages_from_cache( + self, + channels: set[TextChannel], + to_delete: Predicate, + lower_limit: datetime + ) -> tuple[defaultdict[TextChannel, list], list[int]]: """Helper function for getting messages from the cache.""" message_mappings = defaultdict(list) message_ids = [] - for message in islice(self.bot.cached_messages, traverse): + for message in takewhile(lambda m: m.created_at > lower_limit, reversed(self.bot.cached_messages)): if not self.cleaning: # Cleaning was canceled return message_mappings, message_ids - if to_delete(message): + if message.channel in channels and to_delete(message): message_mappings[message.channel].append(message) message_ids.append(message.id) @@ -212,17 +239,16 @@ class Clean(Cog): async def _get_messages_from_channels( self, - traverse: int, channels: Iterable[TextChannel], to_delete: Predicate, - before: Optional[datetime] = None, + before: datetime, after: Optional[datetime] = None - ) -> tuple[defaultdict[Any, list], list]: + ) -> tuple[defaultdict[TextChannel, list], list]: message_mappings = defaultdict(list) message_ids = [] for channel in channels: - async for message in channel.history(limit=traverse, before=before, after=after): + async for message in channel.history(limit=CleanMessages.message_limit, before=before, after=after): if not self.cleaning: # Cleaning was canceled, return empty containers. @@ -318,7 +344,7 @@ class Clean(Cog): # Build the embed and send it if channels == "*": - target_channels = "all channels" + target_channels = "all public channels" else: target_channels = ", ".join(channel.mention for channel in channels) @@ -343,17 +369,15 @@ class Clean(Cog): async def _clean_messages( self, ctx: Context, - traverse: int, channels: Optional[CleanChannels], bots_only: bool = False, users: Optional[list[User]] = None, regex: Optional[re.Pattern] = None, first_limit: Optional[CleanLimit] = None, second_limit: Optional[CleanLimit] = None, - use_cache: Optional[bool] = True ) -> None: """A helper function that does the actual message cleaning.""" - self._validate_input(traverse, channels, bots_only, users, first_limit, second_limit) + self._validate_input(channels, bots_only, users, first_limit, second_limit) # Are we already performing a clean? if self.cleaning: @@ -363,15 +387,7 @@ class Clean(Cog): return self.cleaning = True - # Default to using the invoking context's channel or the channel of the message limit(s). - if not channels: - # Input was validated - if first_limit is a message, second_limit won't point at a different channel. - if isinstance(first_limit, Message): - channels = [first_limit.channel] - elif isinstance(second_limit, Message): - channels = [second_limit.channel] - else: - channels = [ctx.channel] + deletion_channels = self._channels_set(channels, ctx, first_limit, second_limit) if isinstance(first_limit, Message): first_limit = first_limit.created_at @@ -381,19 +397,19 @@ class Clean(Cog): first_limit, second_limit = sorted([first_limit, second_limit]) # Needs to be called after standardizing the input. - predicate = self._build_predicate(bots_only, users, regex, first_limit, second_limit) + predicate = self._build_predicate(first_limit, second_limit, bots_only, users, regex) # Delete the invocation first await self._delete_invocation(ctx) - if channels == "*" and use_cache: - message_mappings, message_ids = self._get_messages_from_cache(traverse=traverse, to_delete=predicate) + if self._use_cache(first_limit): + log.trace(f"Messages for cleaning by {ctx.author.id} will be searched in the cache.") + message_mappings, message_ids = self._get_messages_from_cache( + channels=deletion_channels, to_delete=predicate, lower_limit=first_limit + ) else: - deletion_channels = channels - if channels == "*": - deletion_channels = [channel for channel in ctx.guild.channels if isinstance(channel, TextChannel)] + log.trace(f"Messages for cleaning by {ctx.author.id} will be searched in channel histories.") message_mappings, message_ids = await self._get_messages_from_channels( - traverse=traverse, channels=deletion_channels, to_delete=predicate, before=second_limit, @@ -409,6 +425,8 @@ class Clean(Cog): deleted_messages = await self._delete_found(message_mappings) self.cleaning = False + if not channels: + channels = deletion_channels logged = await self._modlog_cleaned_messages(deleted_messages, channels, ctx) if logged and is_mod_channel(ctx.channel): @@ -422,12 +440,10 @@ class Clean(Cog): self, ctx: Context, users: Greedy[User] = None, - traverse: Optional[int] = None, first_limit: Optional[CleanLimit] = None, second_limit: Optional[CleanLimit] = None, - use_cache: Optional[bool] = None, - bots_only: Optional[bool] = False, regex: Optional[Regex] = None, + bots_only: Optional[bool] = False, *, channels: CleanChannels = None # "Optional" with discord.py silently ignores incorrect input. ) -> None: @@ -437,91 +453,74 @@ class Clean(Cog): If arguments are provided, will act as a master command from which all subcommands can be derived. \u2003• `users`: A series of user mentions, ID's, or names. - \u2003• `traverse`: The number of messages to look at in each channel. If using the cache, will look at the - first `traverse` messages in the cache. \u2003• `first_limit` and `second_limit`: A message, a duration delta, or an ISO datetime. + At least one limit is required. If a message is provided, cleaning will happen in that channel, and channels cannot be provided. - If a limit is provided, multiple channels cannot be provided. If only one of them is provided, acts as `clean until`. If both are provided, acts as `clean between`. - \u2003• `use_cache`: Whether to use the message cache. - If not provided, will default to False unless an asterisk is used for the channels. - \u2003• `bots_only`: Whether to delete only bots. If specified, users cannot be specified. \u2003• `regex`: A regex pattern the message must contain to be deleted. The pattern must be provided enclosed in backticks. If the pattern contains spaces, it still needs to be enclosed in double quotes on top of that. - \u2003• `channels`: A series of channels to delete in, or an asterisk to delete from all channels. + \u2003• `bots_only`: Whether to delete only bots. If specified, users cannot be specified. + \u2003• `channels`: A series of channels to delete in, or an asterisk to delete from all public channels. """ - if not any([traverse, users, first_limit, second_limit, regex, channels]): + if not any([users, first_limit, second_limit, regex, channels]): await ctx.send_help(ctx.command) return - if not traverse: - if first_limit: - traverse = CleanMessages.message_limit - else: - traverse = DEFAULT_TRAVERSE - if use_cache is None: - use_cache = channels == "*" - - await self._clean_messages( - ctx, traverse, channels, bots_only, users, regex, first_limit, second_limit, use_cache - ) + await self._clean_messages(ctx, channels, bots_only, users, regex, first_limit, second_limit) @clean_group.command(name="user", aliases=["users"]) async def clean_user( self, ctx: Context, user: User, - traverse: Optional[int] = DEFAULT_TRAVERSE, - use_cache: Optional[bool] = True, + message_or_time: CleanLimit, *, channels: CleanChannels = None ) -> None: - """Delete messages posted by the provided user, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(ctx, traverse, users=[user], channels=channels, use_cache=use_cache) + """ + Delete messages posted by the provided user, stop cleaning after reaching `message_or_time`. - @clean_group.command(name="all", aliases=["everything"]) - async def clean_all( - self, - ctx: Context, - traverse: Optional[int] = DEFAULT_TRAVERSE, - use_cache: Optional[bool] = True, - *, - channels: CleanChannels = None - ) -> None: - """Delete all messages, regardless of poster, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(ctx, traverse, channels=channels, use_cache=use_cache) + `message_or_time` can be either a message to stop at (exclusive), a timedelta for max message age, or an ISO + datetime. + + If a message is specified, `channels` cannot be specified. + """ + await self._clean_messages(ctx, users=[user], channels=channels, first_limit=message_or_time) @clean_group.command(name="bots", aliases=["bot"]) - async def clean_bots( - self, - ctx: Context, - traverse: Optional[int] = DEFAULT_TRAVERSE, - use_cache: Optional[bool] = True, - *, - channels: CleanChannels = None - ) -> None: - """Delete all messages posted by a bot, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(ctx, traverse, bots_only=True, channels=channels, use_cache=use_cache) + async def clean_bots(self, ctx: Context, message_or_time: CleanLimit, *, channels: CleanChannels = None) -> None: + """ + Delete all messages posted by a bot, stop cleaning after traversing `traverse` messages. + + `message_or_time` can be either a message to stop at (exclusive), a timedelta for max message age, or an ISO + datetime. + + If a message is specified, `channels` cannot be specified. + """ + await self._clean_messages(ctx, bots_only=True, channels=channels, first_limit=message_or_time) @clean_group.command(name="regex", aliases=["word", "expression", "pattern"]) async def clean_regex( self, ctx: Context, regex: Regex, - traverse: Optional[int] = DEFAULT_TRAVERSE, - use_cache: Optional[bool] = True, + message_or_time: CleanLimit, *, channels: CleanChannels = None ) -> None: """ - Delete all messages that match a certain regex, stop cleaning after traversing `traverse` messages. + Delete all messages that match a certain regex, stop cleaning after reaching `message_or_time`. + + `message_or_time` can be either a message to stop at (exclusive), a timedelta for max message age, or an ISO + datetime. + If a message is specified, `channels` cannot be specified. The pattern must be provided enclosed in backticks. If the pattern contains spaces, it still needs to be enclosed in double quotes on top of that. For example: `[0-9]` """ - await self._clean_messages(ctx, traverse, regex=regex, channels=channels, use_cache=use_cache) + await self._clean_messages(ctx, regex=regex, channels=channels, first_limit=message_or_time) @clean_group.command(name="until") async def clean_until( @@ -538,7 +537,6 @@ class Clean(Cog): """ await self._clean_messages( ctx, - CleanMessages.message_limit, channels=[channel] if channel else None, first_limit=until, ) @@ -562,7 +560,6 @@ class Clean(Cog): """ await self._clean_messages( ctx, - CleanMessages.message_limit, channels=[channel] if channel else None, first_limit=first_limit, second_limit=second_limit, |