diff options
| author | 2021-02-23 14:50:35 +0100 | |
|---|---|---|
| committer | 2021-02-23 14:50:35 +0100 | |
| commit | 3797474cabac3fae94a381c0e00998d563efdc5a (patch) | |
| tree | e5bf9aee90cdaa98d2b8ccc99b8665b105eb48fd | |
| parent | rename command `messages` to `until` (diff) | |
Introduce cache to cleaning as well as fix cancel
| -rw-r--r-- | bot/exts/utils/clean.py | 125 |
1 files changed, 89 insertions, 36 deletions
diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py index 7ee0287fd..6301ade04 100644 --- a/bot/exts/utils/clean.py +++ b/bot/exts/utils/clean.py @@ -2,7 +2,8 @@ import logging import random import re import time -from typing import Dict, Iterable, List, Optional +from collections import defaultdict +from typing import Callable, DefaultDict, Iterable, List, Optional from discord import Colour, Embed, Message, NotFound, TextChannel, User from discord.ext import commands @@ -16,6 +17,9 @@ from bot.exts.moderation.modlog import ModLog log = logging.getLogger(__name__) +# Type alias for checks +CheckHint = Callable[[Message], bool] + class Clean(Cog): """ @@ -39,18 +43,74 @@ class Clean(Cog): async def _delete_messages_individually(self, messages: List[Message]) -> None: for message in messages: + # Ensure that deletion was not canceled + if not self.cleaning: + return try: await message.delete() except NotFound: - # message doesn't exist or was already deleted + # Message doesn't exist or was already deleted continue + def _get_messages_from_cache(self, amount: int, check: CheckHint) -> List[DefaultDict, List[int]]: + """Helper function for getting messages from the cache.""" + message_mappings = defaultdict(lambda: []) + message_ids = [] + for message in self.bot.cached_messages: + if not self.cleaning: + # Cleaning was canceled + return (message_mappings, message_ids) + + if check(message): + message_mappings[message.channel].append(message) + message_ids.append(message.id) + + if len(message_ids) == amount: + # We've got the requested amount of messages + return message_mappings, message_ids + + # Amount exceeds amount of messages matching the check + return message_mappings, message_ids + + async def _get_messages_from_channels( + self, + amount: int, + channels: Iterable[TextChannel], + check: CheckHint, + until_message: Optional[Message] = None + ) -> DefaultDict: + message_mappings = defaultdict(lambda: []) + message_ids = [] + + for channel in channels: + + async for message in channel.history(amount=amount): + + if not self.cleaning: + # Cleaning was canceled + return (message_mappings, message_ids) + + if check(message): + message_mappings[message.channel].append(message) + message_ids.append(message.id) + + if until_message: + + # We could use ID's here however in case if the message we are looking for gets deleted, + # We won't have a way to figure that out thus checking for datetime should be more reliable + if message.created_at < until_message.created_at: + # Means we have found the message until which we were supposed to be deleting. + break + + return message_mappings, message_ids + async def _clean_messages( self, amount: int, ctx: Context, channels: Iterable[TextChannel], bots_only: bool = False, + use_cache: bool = False, user: User = None, regex: Optional[str] = None, until_message: Optional[Message] = None, @@ -126,41 +186,21 @@ class Clean(Cog): self.mod_log.ignore(Event.message_delete, ctx.message.id) await ctx.message.delete() - # we need Channel to Message mapping for easier deletion via TextChannel.delete_messages - message_mappings: Dict[TextChannel, List[Message]] = {} - message_ids = [] self.cleaning = True - # Find the IDs of the messages to delete. IDs are needed in order to ignore mod log events. - for channel in channels: - - messages = [] - - async for message in channel.history(limit=amount): - - # If at any point the cancel command is invoked, we should stop. - if not self.cleaning: - return - - # If the message passes predicate, let's save it. - if predicate(message): - messages.append(message) - message_ids.append(message) - - # if we are looking for specific message - if until_message: - - # we could use ID's here however in case if the message we are looking for gets deleted, - # we won't have a way to figure that out thus checking for datetime should be more reliable - if message.created_at < until_message.created_at: - # means we have found the message until which we were supposed to be deleting. - break - - if len(messages) > 0: - # we don't want to create mappings of TextChannel to empty list - message_mappings[channel] = messages + if use_cache: + message_mappings, message_ids = self._get_messages_from_cache(amount, predicate) + else: + message_mappings, message_ids = await self._get_messages_from_channels( + amount=amount, + channels=channels, + check=predicate, + until_message=until_message + ) - self.cleaning = False + if not self.cleaning: + # Means that the cleaning was canceled + return # Now let's delete the actual messages with purge. self.mod_log.ignore(Event.message_delete, *message_ids) @@ -174,9 +214,16 @@ class Clean(Cog): for current_index, message in enumerate(messages): + if not self.cleaning: + # Means that the cleaning was canceled + return + if message.id < minimum_time: # further messages are too old to be deleted in bulk await self._delete_messages_individually(messages[current_index:]) + if not self.cleaning: + # Means that deletion was canceled while deleting the individual messages + return break to_delete.append(message) @@ -241,7 +288,10 @@ class Clean(Cog): channels: commands.Greedy[TextChannel] = None ) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user, channels=channels) + use_cache = True + if channels: + use_cache = False + await self._clean_messages(amount, ctx, user=user, channels=channels, use_cache=use_cache) @clean_group.command(name="all", aliases=["everything"]) @has_any_role(*MODERATION_ROLES) @@ -275,7 +325,10 @@ class Clean(Cog): channels: commands.Greedy[TextChannel] = None ) -> None: """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex, channels=channels) + use_cache = True + if channels: + use_cache = False + await self._clean_messages(amount, ctx, regex=regex, channels=channels, use_cache=use_cache) @clean_group.command(name="until") @has_any_role(*MODERATION_ROLES) |