diff options
author | 2019-12-13 14:21:56 +1000 | |
---|---|---|
committer | 2019-12-13 14:21:56 +1000 | |
commit | f70cdbd8fa5cac0f20382ef0ff6cb1a05ec09df8 (patch) | |
tree | 8ae0210c01a65b9be28a35c8f8f325f8b14d2926 | |
parent | Merge remote-tracking branch 'origin/master' into enhancement/690-clean-cmd (diff) | |
parent | Merge pull request #680 from python-discord/Write-unit-tests-for-`bot/utils/t... (diff) |
Merge branch 'master' into enhancement/690-clean-cmd
-rw-r--r-- | bot/cogs/bot.py | 4 | ||||
-rw-r--r-- | bot/cogs/moderation/management.py | 52 | ||||
-rw-r--r-- | bot/cogs/token_remover.py | 71 | ||||
-rw-r--r-- | bot/converters.py | 23 | ||||
-rw-r--r-- | bot/utils/time.py | 31 | ||||
-rw-r--r-- | tests/bot/utils/test_time.py | 162 |
6 files changed, 297 insertions, 46 deletions
diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index e795e5960..73b1e8f41 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -8,6 +8,7 @@ from discord import Embed, Message, RawMessageUpdateEvent, TextChannel from discord.ext.commands import Cog, Context, command, group from bot.bot import Bot +from bot.cogs.token_remover import TokenRemover from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs from bot.decorators import with_role from bot.utils.messages import wait_for_deletion @@ -239,9 +240,10 @@ class BotCog(Cog, name="Bot"): ) and not msg.author.bot and len(msg.content.splitlines()) > 3 + and not TokenRemover.is_token_in_message(msg) ) - if parse_codeblock: + if parse_codeblock: # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index feae00b7c..9605d47b2 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -10,7 +10,7 @@ from discord.ext.commands import Context from bot import constants from bot.bot import Bot -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, allowed_strings from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -23,15 +23,6 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" @@ -61,8 +52,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], + duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None], *, reason: str = None ) -> None: @@ -79,21 +70,40 @@ class ModManagement(commands.Cog): \u2003`M` - minutesā \u2003`s` - seconds - Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp - can be provided for the duration. + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. """ if duration is None and reason is None: # Unlike UserInputError, the error handler will show a specified message for BadArgument raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + f":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") request_data = {} confirm_messages = [] log_text = "" - if duration == "permanent": + if isinstance(duration, str): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: @@ -130,7 +140,8 @@ class ModManagement(commands.Cog): New expiry: {new_infraction['expires_at'] or "Permanent"} """.rstrip() - await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") # Get information about the infraction's user user_id = new_infraction['user'] @@ -233,6 +244,12 @@ class ModManagement(commands.Cog): user_id = infraction["user"] hidden = infraction["hidden"] created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + if infraction["expires_at"] is None: expires = "*Permanent*" else: @@ -248,6 +265,7 @@ class ModManagement(commands.Cog): Reason: {infraction["reason"] or "*None*"} Created: {created} Expires: {expires} + Remaining: {remaining} Actor: {actor.mention if actor else actor_id} ID: `{infraction["id"]}` {"**===============**" if active else "==============="} diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5d6618338..82c01ae96 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -53,39 +53,60 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ + if self.is_token_in_message(msg): + await self.take_action(msg) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + if self.is_token_in_message(after): + await self.take_action(after) + + async def take_action(self, msg: Message) -> None: + """Remove the `msg` containing a token an send a mod_log message.""" + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') + self.mod_log.ignore(Event.message_delete, msg.id) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + message = ( + "Censored a seemingly valid token sent by " + f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " + f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" + ) + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + @classmethod + def is_token_in_message(cls, msg: Message) -> bool: + """Check if `msg` contains a seemly valid token.""" if msg.author.bot: - return + return False maybe_match = TOKEN_RE.search(msg.content) if maybe_match is None: - return + return False try: user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') except ValueError: - return - - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): - self.mod_log.ignore(Event.message_delete, msg.id) - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - message = ( - "Censored a seemingly valid token sent by " - f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " - f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" - ) - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) + return False + + if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): + return True @staticmethod def is_valid_user_id(b64_content: str) -> bool: diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converted to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") diff --git a/bot/utils/time.py b/bot/utils/time.py index a024674ac..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str: return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: +def format_infraction_with_duration( + expiry: Optional[str], + date_from: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: """ Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. @@ -134,3 +138,28 @@ def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = duration_formatted = f" ({duration})" if duration else '' return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration( + expiry: Optional[str], + now: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: + """ + Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + + Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. + """ + if not expiry: + return None + + now = now or datetime.datetime.utcnow() + since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + + if since < now: + return None + + return humanize_delta(relativedelta(since, now), max_units=max_units) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..69f35f2f5 --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,162 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): + """Test helper functions in bot.utils.time.""" + + def test_humanize_delta_handle_unknown_units(self): + """humanize_delta should be able to handle unknown units, and will not abort.""" + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + + def test_humanize_delta_handle_high_units(self): + """humanize_delta should be able to handle very high units.""" + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + + def test_humanize_delta_should_normal_usage(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + ) + + for delta, precision, max_units, expected in test_cases: + with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_raises_for_invalid_max_units(self): + """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" + test_cases = (-1, 0) + + for max_units in test_cases: + with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: + time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + self.assertEqual(str(error), 'max_units must be positive') + + def test_parse_rfc1123(self): + """Testing parse_rfc1123.""" + self.assertEqual( + time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), + datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) + ) + + def test_format_infraction(self): + """Testing format_infraction.""" + self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') + + @patch('asyncio.sleep', new_callable=AsyncMock) + def test_wait_until(self, mock): + """Testing wait_until.""" + start = datetime(2019, 1, 1, 0, 0) + then = datetime(2019, 1, 1, 0, 10) + + # No return value + self.assertIs(asyncio.run(time.wait_until(then, start)), None) + + mock.assert_called_once_with(10 * 60) + + def test_format_infraction_with_duration_none_expiry(self): + """format_infraction_with_duration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_custom_units(self): + """format_infraction_with_duration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_normal_usage(self): + """format_infraction_with_duration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, + '2019-11-23 23:59 (9 minutes and 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_until_expiration_with_duration_none_expiry(self): + """until_expiration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that now and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_with_duration_custom_units(self): + """until_expiration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_normal_usage(self): + """until_expiration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) |