diff options
| -rw-r--r-- | bot/exts/moderation/clean.py | 230 |
1 files changed, 156 insertions, 74 deletions
diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py index 6fb33c692..bf018e8aa 100644 --- a/bot/exts/moderation/clean.py +++ b/bot/exts/moderation/clean.py @@ -8,7 +8,7 @@ from itertools import islice from typing import Any, Callable, DefaultDict, Iterable, List, Literal, Optional, TYPE_CHECKING, Tuple, Union from discord import Colour, Embed, Message, NotFound, TextChannel, User, errors -from discord.ext.commands import Cog, Context, Converter, group, has_any_role +from discord.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role from discord.ext.commands.converter import TextChannelConverter from discord.ext.commands.errors import BadArgument, MaxConcurrencyReached @@ -22,6 +22,8 @@ from bot.utils.channel import is_mod_channel log = logging.getLogger(__name__) +DEFAULT_TRAVERSE = 10 + # Type alias for checks Predicate = Callable[[Message], bool] @@ -40,8 +42,17 @@ class CleanChannels(Converter): return [await self._channel_converter.convert(ctx, channel) for channel in argument.split()] +class Regex(Converter): + """A converter that takes a string in the form r'.+' and strips the 'r' prefix and the single quotes.""" + + async def convert(self, ctx: Context, argument: str) -> str: + """Strips the 'r' prefix and the enclosing single quotes from the string.""" + return re.match(r"r'(.+?)'", argument).group(1) + + if TYPE_CHECKING: CleanChannels = Union[Literal["*"], list[TextChannel]] # noqa: F811 + Regex = str # noqa: F811 class Clean(Cog): @@ -71,10 +82,9 @@ class Clean(Cog): traverse: int, channels: CleanChannels, bots_only: bool, - user: User, + users: list[User], first_limit: CleanLimit, second_limit: CleanLimit, - use_cache: bool ) -> None: """Raise errors if an argument value or a combination of values is invalid.""" # Is this an acceptable amount of messages to traverse? @@ -89,10 +99,85 @@ class Clean(Cog): if first_limit.channel != second_limit.channel: raise BadArgument("Message limits are in different channels.") + if users and bots_only: + raise BadArgument("Marked as bots only, but users were specified.") + # This is an implementation error rather than user error. if second_limit and not first_limit: raise ValueError("Second limit specified without the first.") + @staticmethod + def _build_predicate( + bots_only: bool = False, + users: list[User] = None, + regex: Optional[str] = None, + first_limit: Optional[datetime] = None, + second_limit: Optional[datetime] = None, + ) -> Predicate: + """Return the predicate that decides whether to delete a given message.""" + def predicate_bots_only(message: Message) -> bool: + """Return True if the message was sent by a bot.""" + return message.author.bot + + def predicate_specific_users(message: Message) -> bool: + """Return True if the message was sent by the user provided in the _clean_messages call.""" + return message.author in users + + def predicate_regex(message: Message) -> bool: + """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" + content = [message.content] + + # Add the content for all embed attributes + for embed in message.embeds: + content.append(embed.title) + content.append(embed.description) + content.append(embed.footer.text) + content.append(embed.author.name) + for field in embed.fields: + content.append(field.name) + content.append(field.value) + + # Get rid of empty attributes and turn it into a string + content = [attr for attr in content if attr] + content = "\n".join(content) + + # Now let's see if there's a regex match + if not content: + return False + else: + return bool(re.search(regex.lower(), content.lower())) + + def predicate_range(message: Message) -> bool: + """Check if the message age is between the two limits.""" + return first_limit <= message.created_at <= second_limit + + def predicate_after(message: Message) -> bool: + """Check if the message is older than the first limit.""" + return message.created_at >= first_limit + + predicates = [] + # Set up the correct predicate + if bots_only: + predicates.append(predicate_bots_only) # Delete messages from bots + if users: + predicates.append(predicate_specific_users) # Delete messages from specific user + if regex: + predicates.append(predicate_regex) # Delete messages that match regex + # Add up to one of the following: + if second_limit: + predicates.append(predicate_range) # Delete messages in the specified age range + elif first_limit: + predicates.append(predicate_after) # Delete messages older than specific message + + if not predicates: + predicate = lambda m: True # Delete all messages # noqa: E731 + elif len(predicates) == 1: + predicate = predicates[0] + else: + predicate = lambda m: all(pred(m) for pred in predicates) # noqa: E731 + + return predicate + def _get_messages_from_cache(self, traverse: int, to_delete: Predicate) -> Tuple[DefaultDict, List[int]]: """Helper function for getting messages from the cache.""" message_mappings = defaultdict(list) @@ -239,78 +324,24 @@ class Clean(Cog): async def _clean_messages( self, - traverse: int, ctx: Context, + traverse: int, channels: CleanChannels, bots_only: bool = False, - user: User = None, + users: list[User] = None, regex: Optional[str] = None, first_limit: Optional[CleanLimit] = None, second_limit: Optional[CleanLimit] = None, use_cache: Optional[bool] = True ) -> None: """A helper function that does the actual message cleaning.""" - def predicate_bots_only(message: Message) -> bool: - """Return True if the message was sent by a bot.""" - return message.author.bot - - def predicate_specific_user(message: Message) -> bool: - """Return True if the message was sent by the user provided in the _clean_messages call.""" - return message.author == user - - def predicate_regex(message: Message) -> bool: - """Check if the regex provided in _clean_messages matches the message content or any embed attributes.""" - content = [message.content] - - # Add the content for all embed attributes - for embed in message.embeds: - content.append(embed.title) - content.append(embed.description) - content.append(embed.footer.text) - content.append(embed.author.name) - for field in embed.fields: - content.append(field.name) - content.append(field.value) - - # Get rid of empty attributes and turn it into a string - content = [attr for attr in content if attr] - content = "\n".join(content) - - # Now let's see if there's a regex match - if not content: - return False - else: - return bool(re.search(regex.lower(), content.lower())) - - def predicate_range(message: Message) -> bool: - """Check if the message age is between the two limits.""" - return first_limit <= message.created_at <= second_limit - - def predicate_after(message: Message) -> bool: - """Check if the message is older than the first limit.""" - return message.created_at >= first_limit - - self._validate_input(traverse, channels, bots_only, user, first_limit, second_limit, use_cache) + self._validate_input(traverse, channels, bots_only, users, first_limit, second_limit) # Are we already performing a clean? if self.cleaning: raise MaxConcurrencyReached("Please wait for the currently ongoing clean operation to complete.") self.cleaning = True - # Set up the correct predicate - if bots_only: - predicate = predicate_bots_only # Delete messages from bots - elif user: - predicate = predicate_specific_user # Delete messages from specific user - elif regex: - predicate = predicate_regex # Delete messages that match regex - elif second_limit: - predicate = predicate_range # Delete messages in the specified age range - elif first_limit: - predicate = predicate_after # Delete messages older than specific message - else: - predicate = lambda m: True # Delete all messages # noqa: E731 - # Default to using the invoking context's channel or the channel of the message limit(s). if not channels: # At this point second_limit is guaranteed to not exist, be a datetime, or a message in the same channel. @@ -328,6 +359,9 @@ class Clean(Cog): if first_limit and second_limit: first_limit, second_limit = sorted([first_limit, second_limit]) + # Needs to be called after standardizing the input. + predicate = self._build_predicate(bots_only, users, regex, first_limit, second_limit) + if not is_mod_channel(ctx.channel): # Delete the invocation first self.mod_log.ignore(Event.message_delete, ctx.message.id) @@ -369,9 +403,51 @@ class Clean(Cog): # region: Commands @group(invoke_without_command=True, name="clean", aliases=["clear", "purge"]) - async def clean_group(self, ctx: Context) -> None: - """Commands for cleaning messages in channels.""" - await ctx.send_help(ctx.command) + async def clean_group( + self, + ctx: Context, + traverse: Optional[int] = None, + users: Greedy[User] = None, + first_limit: Optional[CleanLimit] = None, + second_limit: Optional[CleanLimit] = None, + use_cache: Optional[bool] = None, + bots_only: Optional[bool] = False, + regex: Optional[Regex] = None, + *, + channels: Optional[CleanChannels] = None + ) -> None: + """ + Commands for cleaning messages in channels. + + If arguments are provided, will act as a master command from which all subcommands can be derived. + `traverse`: The number of messages to look at in each channel. + `users`: A series of user mentions, ID's, or names. + `first_limit` and `second_limit`: A message, a duration delta, or an ISO datetime. + If a message is provided, cleaning will happen in that channel, and channels cannot be provided. + If only one of them is provided, acts as `clean until`. If both are provided, acts as `clean between`. + `use_cache`: Whether to use the message cache. + If not provided, will default to False unless an asterisk is used for the channels. + `bots_only`: Whether to delete only bots. If specified, users cannot be specified. + `regex`: A regex pattern the message must contain to be deleted. + The pattern must be provided with an "r" prefix and enclosed in single quotes. + If the pattern contains spaces, it still needs to be enclosed in double quotes on top of that. + `channels`: A series of channels to delete in, or an asterisk to delete from all channels. + """ + if not any([traverse, users, first_limit, second_limit, regex]): + await ctx.send_help(ctx.command) + return + + if not traverse: + if first_limit: + traverse = CleanMessages.message_limit + else: + traverse = DEFAULT_TRAVERSE + if not use_cache: + use_cache = channels == "*" + + await self._clean_messages( + ctx, traverse, channels, bots_only, users, regex, first_limit, second_limit, use_cache + ) @clean_group.command(name="user", aliases=["users"]) async def clean_user( @@ -384,44 +460,50 @@ class Clean(Cog): channels: Optional[CleanChannels] = None ) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(traverse, ctx, user=user, channels=channels, use_cache=use_cache) + await self._clean_messages(ctx, traverse, users=[user], channels=channels, use_cache=use_cache) @clean_group.command(name="all", aliases=["everything"]) async def clean_all( self, ctx: Context, - traverse: Optional[int] = 10, + traverse: Optional[int] = DEFAULT_TRAVERSE, use_cache: Optional[bool] = True, *, channels: Optional[CleanChannels] = None ) -> None: """Delete all messages, regardless of poster, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(traverse, ctx, channels=channels, use_cache=use_cache) + await self._clean_messages(ctx, traverse, channels=channels, use_cache=use_cache) @clean_group.command(name="bots", aliases=["bot"]) async def clean_bots( self, ctx: Context, - traverse: Optional[int] = 10, + traverse: Optional[int] = DEFAULT_TRAVERSE, use_cache: Optional[bool] = True, *, channels: Optional[CleanChannels] = None ) -> None: """Delete all messages posted by a bot, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(traverse, ctx, bots_only=True, channels=channels, use_cache=use_cache) + await self._clean_messages(ctx, traverse, bots_only=True, channels=channels, use_cache=use_cache) @clean_group.command(name="regex", aliases=["word", "expression", "pattern"]) async def clean_regex( self, ctx: Context, - regex: str, - traverse: Optional[int] = 10, + regex: Regex, + traverse: Optional[int] = DEFAULT_TRAVERSE, use_cache: Optional[bool] = True, *, channels: Optional[CleanChannels] = None ) -> None: - """Delete all messages that match a certain regex, stop cleaning after traversing `traverse` messages.""" - await self._clean_messages(traverse, ctx, regex=regex, channels=channels, use_cache=use_cache) + """ + Delete all messages that match a certain regex, stop cleaning after traversing `traverse` messages. + + The pattern must be provided with an "r" prefix and enclosed in single quotes. + If the pattern contains spaces, and still needs to be enclosed in double quotes on top of that. + For example: r'[0-9]+' + """ + await self._clean_messages(ctx, traverse, regex=regex, channels=channels, use_cache=use_cache) @clean_group.command(name="until") async def clean_until( @@ -437,8 +519,8 @@ class Clean(Cog): If a message is specified, `channel` cannot be specified. """ await self._clean_messages( - CleanMessages.message_limit, ctx, + CleanMessages.message_limit, channels=[channel] if channel else None, first_limit=until, ) @@ -461,8 +543,8 @@ class Clean(Cog): If a message is specified, `channel` cannot be specified. """ await self._clean_messages( - CleanMessages.message_limit, ctx, + CleanMessages.message_limit, channels=[channel] if channel else None, first_limit=first_limit, second_limit=second_limit, |