aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar scragly <[email protected]>2019-12-13 14:21:56 +1000
committerGravatar GitHub <[email protected]>2019-12-13 14:21:56 +1000
commitf70cdbd8fa5cac0f20382ef0ff6cb1a05ec09df8 (patch)
tree8ae0210c01a65b9be28a35c8f8f325f8b14d2926
parentMerge remote-tracking branch 'origin/master' into enhancement/690-clean-cmd (diff)
parentMerge pull request #680 from python-discord/Write-unit-tests-for-`bot/utils/t... (diff)
Merge branch 'master' into enhancement/690-clean-cmd
-rw-r--r--bot/cogs/bot.py4
-rw-r--r--bot/cogs/moderation/management.py52
-rw-r--r--bot/cogs/token_remover.py71
-rw-r--r--bot/converters.py23
-rw-r--r--bot/utils/time.py31
-rw-r--r--tests/bot/utils/test_time.py162
6 files changed, 297 insertions, 46 deletions
diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py
index e795e5960..73b1e8f41 100644
--- a/bot/cogs/bot.py
+++ b/bot/cogs/bot.py
@@ -8,6 +8,7 @@ from discord import Embed, Message, RawMessageUpdateEvent, TextChannel
from discord.ext.commands import Cog, Context, command, group
from bot.bot import Bot
+from bot.cogs.token_remover import TokenRemover
from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs
from bot.decorators import with_role
from bot.utils.messages import wait_for_deletion
@@ -239,9 +240,10 @@ class BotCog(Cog, name="Bot"):
)
and not msg.author.bot
and len(msg.content.splitlines()) > 3
+ and not TokenRemover.is_token_in_message(msg)
)
- if parse_codeblock:
+ if parse_codeblock: # no token in the msg
on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300
if not on_cooldown or DEBUG_MODE:
try:
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index feae00b7c..9605d47b2 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -10,7 +10,7 @@ from discord.ext.commands import Context
from bot import constants
from bot.bot import Bot
-from bot.converters import InfractionSearchQuery
+from bot.converters import InfractionSearchQuery, allowed_strings
from bot.pagination import LinePaginator
from bot.utils import time
from bot.utils.checks import in_channel_check, with_role_check
@@ -23,15 +23,6 @@ log = logging.getLogger(__name__)
UserConverter = t.Union[discord.User, utils.proxy_user]
-def permanent_duration(expires_at: str) -> str:
- """Only allow an expiration to be 'permanent' if it is a string."""
- expires_at = expires_at.lower()
- if expires_at != "permanent":
- raise commands.BadArgument
- else:
- return expires_at
-
-
class ModManagement(commands.Cog):
"""Management of infractions."""
@@ -61,8 +52,8 @@ class ModManagement(commands.Cog):
async def infraction_edit(
self,
ctx: Context,
- infraction_id: int,
- duration: t.Union[utils.Expiry, permanent_duration, None],
+ infraction_id: t.Union[int, allowed_strings("l", "last", "recent")],
+ duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None],
*,
reason: str = None
) -> None:
@@ -79,21 +70,40 @@ class ModManagement(commands.Cog):
\u2003`M` - minutesāˆ—
\u2003`s` - seconds
- Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp
- can be provided for the duration.
+ Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction
+ authored by the command invoker should be edited.
+
+ Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601
+ timestamp can be provided for the duration.
"""
if duration is None and reason is None:
# Unlike UserInputError, the error handler will show a specified message for BadArgument
raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")
# Retrieve the previous infraction for its information.
- old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}')
+ if isinstance(infraction_id, str):
+ params = {
+ "actor__id": ctx.author.id,
+ "ordering": "-inserted_at"
+ }
+ infractions = await self.bot.api_client.get(f"bot/infractions", params=params)
+
+ if infractions:
+ old_infraction = infractions[0]
+ infraction_id = old_infraction["id"]
+ else:
+ await ctx.send(
+ f":x: Couldn't find most recent infraction; you have never given an infraction."
+ )
+ return
+ else:
+ old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}")
request_data = {}
confirm_messages = []
log_text = ""
- if duration == "permanent":
+ if isinstance(duration, str):
request_data['expires_at'] = None
confirm_messages.append("marked as permanent")
elif duration is not None:
@@ -130,7 +140,8 @@ class ModManagement(commands.Cog):
New expiry: {new_infraction['expires_at'] or "Permanent"}
""".rstrip()
- await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}")
+ changes = ' & '.join(confirm_messages)
+ await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}")
# Get information about the infraction's user
user_id = new_infraction['user']
@@ -233,6 +244,12 @@ class ModManagement(commands.Cog):
user_id = infraction["user"]
hidden = infraction["hidden"]
created = time.format_infraction(infraction["inserted_at"])
+
+ if active:
+ remaining = time.until_expiration(infraction["expires_at"]) or "Expired"
+ else:
+ remaining = "Inactive"
+
if infraction["expires_at"] is None:
expires = "*Permanent*"
else:
@@ -248,6 +265,7 @@ class ModManagement(commands.Cog):
Reason: {infraction["reason"] or "*None*"}
Created: {created}
Expires: {expires}
+ Remaining: {remaining}
Actor: {actor.mention if actor else actor_id}
ID: `{infraction["id"]}`
{"**===============**" if active else "==============="}
diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py
index 5d6618338..82c01ae96 100644
--- a/bot/cogs/token_remover.py
+++ b/bot/cogs/token_remover.py
@@ -53,39 +53,60 @@ class TokenRemover(Cog):
See: https://discordapp.com/developers/docs/reference#snowflakes
"""
+ if self.is_token_in_message(msg):
+ await self.take_action(msg)
+
+ @Cog.listener()
+ async def on_message_edit(self, before: Message, after: Message) -> None:
+ """
+ Check each edit for a string that matches Discord's token pattern.
+
+ See: https://discordapp.com/developers/docs/reference#snowflakes
+ """
+ if self.is_token_in_message(after):
+ await self.take_action(after)
+
+ async def take_action(self, msg: Message) -> None:
+ """Remove the `msg` containing a token an send a mod_log message."""
+ user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.')
+ self.mod_log.ignore(Event.message_delete, msg.id)
+ await msg.delete()
+ await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention))
+
+ message = (
+ "Censored a seemingly valid token sent by "
+ f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was "
+ f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`"
+ )
+ log.debug(message)
+
+ # Send pretty mod log embed to mod-alerts
+ await self.mod_log.send_log_message(
+ icon_url=Icons.token_removed,
+ colour=Colour(Colours.soft_red),
+ title="Token removed!",
+ text=message,
+ thumbnail=msg.author.avatar_url_as(static_format="png"),
+ channel_id=Channels.mod_alerts,
+ )
+
+ @classmethod
+ def is_token_in_message(cls, msg: Message) -> bool:
+ """Check if `msg` contains a seemly valid token."""
if msg.author.bot:
- return
+ return False
maybe_match = TOKEN_RE.search(msg.content)
if maybe_match is None:
- return
+ return False
try:
user_id, creation_timestamp, hmac = maybe_match.group(0).split('.')
except ValueError:
- return
-
- if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp):
- self.mod_log.ignore(Event.message_delete, msg.id)
- await msg.delete()
- await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention))
-
- message = (
- "Censored a seemingly valid token sent by "
- f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was "
- f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`"
- )
- log.debug(message)
-
- # Send pretty mod log embed to mod-alerts
- await self.mod_log.send_log_message(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Token removed!",
- text=message,
- thumbnail=msg.author.avatar_url_as(static_format="png"),
- channel_id=Channels.mod_alerts,
- )
+ return False
+
+ if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp):
+ return True
@staticmethod
def is_valid_user_id(b64_content: str) -> bool:
diff --git a/bot/converters.py b/bot/converters.py
index cf0496541..8d2ab7eb8 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,8 +1,8 @@
import logging
import re
+import typing as t
from datetime import datetime
from ssl import CertificateError
-from typing import Union
import dateutil.parser
import dateutil.tz
@@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter
log = logging.getLogger(__name__)
+def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]:
+ """
+ Return a converter which only allows arguments equal to one of the given values.
+
+ Unless preserve_case is True, the argument is converted to lowercase. All values are then
+ expected to have already been given in lowercase too.
+ """
+ def converter(arg: str) -> str:
+ if not preserve_case:
+ arg = arg.lower()
+
+ if arg not in values:
+ raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```")
+ else:
+ return arg
+
+ return converter
+
+
class ValidPythonIdentifier(Converter):
"""
A converter that checks whether the given string is a valid Python identifier.
@@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):
"""A converter that checks if the argument is a Discord user, and if not, falls back to a string."""
@staticmethod
- async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]:
+ async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:
"""Check if the argument is a Discord user, and if not, falls back to a string."""
try:
maybe_snowflake = arg.strip("<@!>")
diff --git a/bot/utils/time.py b/bot/utils/time.py
index a024674ac..7416f36e0 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str:
return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT)
-def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str:
+def format_infraction_with_duration(
+ expiry: Optional[str],
+ date_from: Optional[datetime.datetime] = None,
+ max_units: int = 2
+) -> Optional[str]:
"""
Format an infraction timestamp to a more readable ISO 8601 format WITH the duration.
@@ -134,3 +138,28 @@ def format_infraction_with_duration(expiry: str, date_from: datetime.datetime =
duration_formatted = f" ({duration})" if duration else ''
return f"{expiry_formatted}{duration_formatted}"
+
+
+def until_expiration(
+ expiry: Optional[str],
+ now: Optional[datetime.datetime] = None,
+ max_units: int = 2
+) -> Optional[str]:
+ """
+ Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta.
+
+ Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry.
+ Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it.
+ `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
+ By default, max_units is 2.
+ """
+ if not expiry:
+ return None
+
+ now = now or datetime.datetime.utcnow()
+ since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0)
+
+ if since < now:
+ return None
+
+ return humanize_delta(relativedelta(since, now), max_units=max_units)
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
new file mode 100644
index 000000000..69f35f2f5
--- /dev/null
+++ b/tests/bot/utils/test_time.py
@@ -0,0 +1,162 @@
+import asyncio
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from dateutil.relativedelta import relativedelta
+
+from bot.utils import time
+from tests.helpers import AsyncMock
+
+
+class TimeTests(unittest.TestCase):
+ """Test helper functions in bot.utils.time."""
+
+ def test_humanize_delta_handle_unknown_units(self):
+ """humanize_delta should be able to handle unknown units, and will not abort."""
+ # Does not abort for unknown units, as the unit name is checked
+ # against the attribute of the relativedelta instance.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours')
+
+ def test_humanize_delta_handle_high_units(self):
+ """humanize_delta should be able to handle very high units."""
+ # Very high maximum units, but it only ever iterates over
+ # each value the relativedelta might have.
+ self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours')
+
+ def test_humanize_delta_should_normal_usage(self):
+ """Testing humanize delta."""
+ test_cases = (
+ (relativedelta(days=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
+ (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
+ )
+
+ for delta, precision, max_units, expected in test_cases:
+ with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected):
+ self.assertEqual(time.humanize_delta(delta, precision, max_units), expected)
+
+ def test_humanize_delta_raises_for_invalid_max_units(self):
+ """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units."""
+ test_cases = (-1, 0)
+
+ for max_units in test_cases:
+ with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
+ time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
+ self.assertEqual(str(error), 'max_units must be positive')
+
+ def test_parse_rfc1123(self):
+ """Testing parse_rfc1123."""
+ self.assertEqual(
+ time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'),
+ datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)
+ )
+
+ def test_format_infraction(self):
+ """Testing format_infraction."""
+ self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01')
+
+ @patch('asyncio.sleep', new_callable=AsyncMock)
+ def test_wait_until(self, mock):
+ """Testing wait_until."""
+ start = datetime(2019, 1, 1, 0, 0)
+ then = datetime(2019, 1, 1, 0, 10)
+
+ # No return value
+ self.assertIs(asyncio.run(time.wait_until(then, start)), None)
+
+ mock.assert_called_once_with(10 * 60)
+
+ def test_format_infraction_with_duration_none_expiry(self):
+ """format_infraction_with_duration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that date_from and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_custom_units(self):
+ """format_infraction_with_duration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6,
+ '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20,
+ '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)')
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_format_infraction_with_duration_normal_usage(self):
+ """format_infraction_with_duration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2,
+ '2019-11-23 23:59 (9 minutes and 55 seconds)'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, date_from, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
+ self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+
+ def test_until_expiration_with_duration_none_expiry(self):
+ """until_expiration should work for None expiry."""
+ test_cases = (
+ (None, None, None, None),
+
+ # To make sure that now and max_units are not touched
+ (None, 'Why hello there!', None, None),
+ (None, None, float('inf'), None),
+ (None, 'Why hello there!', float('inf'), None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_with_duration_custom_units(self):
+ """until_expiration should work for custom max_units."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes')
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+
+ def test_until_expiration_normal_usage(self):
+ """until_expiration should work for normal usage, across various durations."""
+ test_cases = (
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'),
+ (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ )
+
+ for expiry, now, max_units, expected in test_cases:
+ with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
+ self.assertEqual(time.until_expiration(expiry, now, max_units), expected)