aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/moderation/clean.py230
1 files changed, 156 insertions, 74 deletions
diff --git a/bot/exts/moderation/clean.py b/bot/exts/moderation/clean.py
index 6fb33c692..bf018e8aa 100644
--- a/bot/exts/moderation/clean.py
+++ b/bot/exts/moderation/clean.py
@@ -8,7 +8,7 @@ from itertools import islice
from typing import Any, Callable, DefaultDict, Iterable, List, Literal, Optional, TYPE_CHECKING, Tuple, Union
from discord import Colour, Embed, Message, NotFound, TextChannel, User, errors
-from discord.ext.commands import Cog, Context, Converter, group, has_any_role
+from discord.ext.commands import Cog, Context, Converter, Greedy, group, has_any_role
from discord.ext.commands.converter import TextChannelConverter
from discord.ext.commands.errors import BadArgument, MaxConcurrencyReached
@@ -22,6 +22,8 @@ from bot.utils.channel import is_mod_channel
log = logging.getLogger(__name__)
+DEFAULT_TRAVERSE = 10
+
# Type alias for checks
Predicate = Callable[[Message], bool]
@@ -40,8 +42,17 @@ class CleanChannels(Converter):
return [await self._channel_converter.convert(ctx, channel) for channel in argument.split()]
+class Regex(Converter):
+ """A converter that takes a string in the form r'.+' and strips the 'r' prefix and the single quotes."""
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Strips the 'r' prefix and the enclosing single quotes from the string."""
+ return re.match(r"r'(.+?)'", argument).group(1)
+
+
if TYPE_CHECKING:
CleanChannels = Union[Literal["*"], list[TextChannel]] # noqa: F811
+ Regex = str # noqa: F811
class Clean(Cog):
@@ -71,10 +82,9 @@ class Clean(Cog):
traverse: int,
channels: CleanChannels,
bots_only: bool,
- user: User,
+ users: list[User],
first_limit: CleanLimit,
second_limit: CleanLimit,
- use_cache: bool
) -> None:
"""Raise errors if an argument value or a combination of values is invalid."""
# Is this an acceptable amount of messages to traverse?
@@ -89,10 +99,85 @@ class Clean(Cog):
if first_limit.channel != second_limit.channel:
raise BadArgument("Message limits are in different channels.")
+ if users and bots_only:
+ raise BadArgument("Marked as bots only, but users were specified.")
+
# This is an implementation error rather than user error.
if second_limit and not first_limit:
raise ValueError("Second limit specified without the first.")
+ @staticmethod
+ def _build_predicate(
+ bots_only: bool = False,
+ users: list[User] = None,
+ regex: Optional[str] = None,
+ first_limit: Optional[datetime] = None,
+ second_limit: Optional[datetime] = None,
+ ) -> Predicate:
+ """Return the predicate that decides whether to delete a given message."""
+ def predicate_bots_only(message: Message) -> bool:
+ """Return True if the message was sent by a bot."""
+ return message.author.bot
+
+ def predicate_specific_users(message: Message) -> bool:
+ """Return True if the message was sent by the user provided in the _clean_messages call."""
+ return message.author in users
+
+ def predicate_regex(message: Message) -> bool:
+ """Check if the regex provided in _clean_messages matches the message content or any embed attributes."""
+ content = [message.content]
+
+ # Add the content for all embed attributes
+ for embed in message.embeds:
+ content.append(embed.title)
+ content.append(embed.description)
+ content.append(embed.footer.text)
+ content.append(embed.author.name)
+ for field in embed.fields:
+ content.append(field.name)
+ content.append(field.value)
+
+ # Get rid of empty attributes and turn it into a string
+ content = [attr for attr in content if attr]
+ content = "\n".join(content)
+
+ # Now let's see if there's a regex match
+ if not content:
+ return False
+ else:
+ return bool(re.search(regex.lower(), content.lower()))
+
+ def predicate_range(message: Message) -> bool:
+ """Check if the message age is between the two limits."""
+ return first_limit <= message.created_at <= second_limit
+
+ def predicate_after(message: Message) -> bool:
+ """Check if the message is older than the first limit."""
+ return message.created_at >= first_limit
+
+ predicates = []
+ # Set up the correct predicate
+ if bots_only:
+ predicates.append(predicate_bots_only) # Delete messages from bots
+ if users:
+ predicates.append(predicate_specific_users) # Delete messages from specific user
+ if regex:
+ predicates.append(predicate_regex) # Delete messages that match regex
+ # Add up to one of the following:
+ if second_limit:
+ predicates.append(predicate_range) # Delete messages in the specified age range
+ elif first_limit:
+ predicates.append(predicate_after) # Delete messages older than specific message
+
+ if not predicates:
+ predicate = lambda m: True # Delete all messages # noqa: E731
+ elif len(predicates) == 1:
+ predicate = predicates[0]
+ else:
+ predicate = lambda m: all(pred(m) for pred in predicates) # noqa: E731
+
+ return predicate
+
def _get_messages_from_cache(self, traverse: int, to_delete: Predicate) -> Tuple[DefaultDict, List[int]]:
"""Helper function for getting messages from the cache."""
message_mappings = defaultdict(list)
@@ -239,78 +324,24 @@ class Clean(Cog):
async def _clean_messages(
self,
- traverse: int,
ctx: Context,
+ traverse: int,
channels: CleanChannels,
bots_only: bool = False,
- user: User = None,
+ users: list[User] = None,
regex: Optional[str] = None,
first_limit: Optional[CleanLimit] = None,
second_limit: Optional[CleanLimit] = None,
use_cache: Optional[bool] = True
) -> None:
"""A helper function that does the actual message cleaning."""
- def predicate_bots_only(message: Message) -> bool:
- """Return True if the message was sent by a bot."""
- return message.author.bot
-
- def predicate_specific_user(message: Message) -> bool:
- """Return True if the message was sent by the user provided in the _clean_messages call."""
- return message.author == user
-
- def predicate_regex(message: Message) -> bool:
- """Check if the regex provided in _clean_messages matches the message content or any embed attributes."""
- content = [message.content]
-
- # Add the content for all embed attributes
- for embed in message.embeds:
- content.append(embed.title)
- content.append(embed.description)
- content.append(embed.footer.text)
- content.append(embed.author.name)
- for field in embed.fields:
- content.append(field.name)
- content.append(field.value)
-
- # Get rid of empty attributes and turn it into a string
- content = [attr for attr in content if attr]
- content = "\n".join(content)
-
- # Now let's see if there's a regex match
- if not content:
- return False
- else:
- return bool(re.search(regex.lower(), content.lower()))
-
- def predicate_range(message: Message) -> bool:
- """Check if the message age is between the two limits."""
- return first_limit <= message.created_at <= second_limit
-
- def predicate_after(message: Message) -> bool:
- """Check if the message is older than the first limit."""
- return message.created_at >= first_limit
-
- self._validate_input(traverse, channels, bots_only, user, first_limit, second_limit, use_cache)
+ self._validate_input(traverse, channels, bots_only, users, first_limit, second_limit)
# Are we already performing a clean?
if self.cleaning:
raise MaxConcurrencyReached("Please wait for the currently ongoing clean operation to complete.")
self.cleaning = True
- # Set up the correct predicate
- if bots_only:
- predicate = predicate_bots_only # Delete messages from bots
- elif user:
- predicate = predicate_specific_user # Delete messages from specific user
- elif regex:
- predicate = predicate_regex # Delete messages that match regex
- elif second_limit:
- predicate = predicate_range # Delete messages in the specified age range
- elif first_limit:
- predicate = predicate_after # Delete messages older than specific message
- else:
- predicate = lambda m: True # Delete all messages # noqa: E731
-
# Default to using the invoking context's channel or the channel of the message limit(s).
if not channels:
# At this point second_limit is guaranteed to not exist, be a datetime, or a message in the same channel.
@@ -328,6 +359,9 @@ class Clean(Cog):
if first_limit and second_limit:
first_limit, second_limit = sorted([first_limit, second_limit])
+ # Needs to be called after standardizing the input.
+ predicate = self._build_predicate(bots_only, users, regex, first_limit, second_limit)
+
if not is_mod_channel(ctx.channel):
# Delete the invocation first
self.mod_log.ignore(Event.message_delete, ctx.message.id)
@@ -369,9 +403,51 @@ class Clean(Cog):
# region: Commands
@group(invoke_without_command=True, name="clean", aliases=["clear", "purge"])
- async def clean_group(self, ctx: Context) -> None:
- """Commands for cleaning messages in channels."""
- await ctx.send_help(ctx.command)
+ async def clean_group(
+ self,
+ ctx: Context,
+ traverse: Optional[int] = None,
+ users: Greedy[User] = None,
+ first_limit: Optional[CleanLimit] = None,
+ second_limit: Optional[CleanLimit] = None,
+ use_cache: Optional[bool] = None,
+ bots_only: Optional[bool] = False,
+ regex: Optional[Regex] = None,
+ *,
+ channels: Optional[CleanChannels] = None
+ ) -> None:
+ """
+ Commands for cleaning messages in channels.
+
+ If arguments are provided, will act as a master command from which all subcommands can be derived.
+ `traverse`: The number of messages to look at in each channel.
+ `users`: A series of user mentions, ID's, or names.
+ `first_limit` and `second_limit`: A message, a duration delta, or an ISO datetime.
+ If a message is provided, cleaning will happen in that channel, and channels cannot be provided.
+ If only one of them is provided, acts as `clean until`. If both are provided, acts as `clean between`.
+ `use_cache`: Whether to use the message cache.
+ If not provided, will default to False unless an asterisk is used for the channels.
+ `bots_only`: Whether to delete only bots. If specified, users cannot be specified.
+ `regex`: A regex pattern the message must contain to be deleted.
+ The pattern must be provided with an "r" prefix and enclosed in single quotes.
+ If the pattern contains spaces, it still needs to be enclosed in double quotes on top of that.
+ `channels`: A series of channels to delete in, or an asterisk to delete from all channels.
+ """
+ if not any([traverse, users, first_limit, second_limit, regex]):
+ await ctx.send_help(ctx.command)
+ return
+
+ if not traverse:
+ if first_limit:
+ traverse = CleanMessages.message_limit
+ else:
+ traverse = DEFAULT_TRAVERSE
+ if not use_cache:
+ use_cache = channels == "*"
+
+ await self._clean_messages(
+ ctx, traverse, channels, bots_only, users, regex, first_limit, second_limit, use_cache
+ )
@clean_group.command(name="user", aliases=["users"])
async def clean_user(
@@ -384,44 +460,50 @@ class Clean(Cog):
channels: Optional[CleanChannels] = None
) -> None:
"""Delete messages posted by the provided user, stop cleaning after traversing `traverse` messages."""
- await self._clean_messages(traverse, ctx, user=user, channels=channels, use_cache=use_cache)
+ await self._clean_messages(ctx, traverse, users=[user], channels=channels, use_cache=use_cache)
@clean_group.command(name="all", aliases=["everything"])
async def clean_all(
self,
ctx: Context,
- traverse: Optional[int] = 10,
+ traverse: Optional[int] = DEFAULT_TRAVERSE,
use_cache: Optional[bool] = True,
*,
channels: Optional[CleanChannels] = None
) -> None:
"""Delete all messages, regardless of poster, stop cleaning after traversing `traverse` messages."""
- await self._clean_messages(traverse, ctx, channels=channels, use_cache=use_cache)
+ await self._clean_messages(ctx, traverse, channels=channels, use_cache=use_cache)
@clean_group.command(name="bots", aliases=["bot"])
async def clean_bots(
self,
ctx: Context,
- traverse: Optional[int] = 10,
+ traverse: Optional[int] = DEFAULT_TRAVERSE,
use_cache: Optional[bool] = True,
*,
channels: Optional[CleanChannels] = None
) -> None:
"""Delete all messages posted by a bot, stop cleaning after traversing `traverse` messages."""
- await self._clean_messages(traverse, ctx, bots_only=True, channels=channels, use_cache=use_cache)
+ await self._clean_messages(ctx, traverse, bots_only=True, channels=channels, use_cache=use_cache)
@clean_group.command(name="regex", aliases=["word", "expression", "pattern"])
async def clean_regex(
self,
ctx: Context,
- regex: str,
- traverse: Optional[int] = 10,
+ regex: Regex,
+ traverse: Optional[int] = DEFAULT_TRAVERSE,
use_cache: Optional[bool] = True,
*,
channels: Optional[CleanChannels] = None
) -> None:
- """Delete all messages that match a certain regex, stop cleaning after traversing `traverse` messages."""
- await self._clean_messages(traverse, ctx, regex=regex, channels=channels, use_cache=use_cache)
+ """
+ Delete all messages that match a certain regex, stop cleaning after traversing `traverse` messages.
+
+ The pattern must be provided with an "r" prefix and enclosed in single quotes.
+ If the pattern contains spaces, and still needs to be enclosed in double quotes on top of that.
+ For example: r'[0-9]+'
+ """
+ await self._clean_messages(ctx, traverse, regex=regex, channels=channels, use_cache=use_cache)
@clean_group.command(name="until")
async def clean_until(
@@ -437,8 +519,8 @@ class Clean(Cog):
If a message is specified, `channel` cannot be specified.
"""
await self._clean_messages(
- CleanMessages.message_limit,
ctx,
+ CleanMessages.message_limit,
channels=[channel] if channel else None,
first_limit=until,
)
@@ -461,8 +543,8 @@ class Clean(Cog):
If a message is specified, `channel` cannot be specified.
"""
await self._clean_messages(
- CleanMessages.message_limit,
ctx,
+ CleanMessages.message_limit,
channels=[channel] if channel else None,
first_limit=first_limit,
second_limit=second_limit,