diff options
| author | 2019-12-13 14:21:56 +1000 | |
|---|---|---|
| committer | 2019-12-13 14:21:56 +1000 | |
| commit | f70cdbd8fa5cac0f20382ef0ff6cb1a05ec09df8 (patch) | |
| tree | 8ae0210c01a65b9be28a35c8f8f325f8b14d2926 | |
| parent | Merge remote-tracking branch 'origin/master' into enhancement/690-clean-cmd (diff) | |
| parent | Merge pull request #680 from python-discord/Write-unit-tests-for-`bot/utils/t... (diff) | |
Merge branch 'master' into enhancement/690-clean-cmd
| -rw-r--r-- | bot/cogs/bot.py | 4 | ||||
| -rw-r--r-- | bot/cogs/moderation/management.py | 52 | ||||
| -rw-r--r-- | bot/cogs/token_remover.py | 71 | ||||
| -rw-r--r-- | bot/converters.py | 23 | ||||
| -rw-r--r-- | bot/utils/time.py | 31 | ||||
| -rw-r--r-- | tests/bot/utils/test_time.py | 162 | 
6 files changed, 297 insertions, 46 deletions
| diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index e795e5960..73b1e8f41 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -8,6 +8,7 @@ from discord import Embed, Message, RawMessageUpdateEvent, TextChannel  from discord.ext.commands import Cog, Context, command, group  from bot.bot import Bot +from bot.cogs.token_remover import TokenRemover  from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs  from bot.decorators import with_role  from bot.utils.messages import wait_for_deletion @@ -239,9 +240,10 @@ class BotCog(Cog, name="Bot"):              )              and not msg.author.bot              and len(msg.content.splitlines()) > 3 +            and not TokenRemover.is_token_in_message(msg)          ) -        if parse_codeblock: +        if parse_codeblock:  # no token in the msg              on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300              if not on_cooldown or DEBUG_MODE:                  try: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index feae00b7c..9605d47b2 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -10,7 +10,7 @@ from discord.ext.commands import Context  from bot import constants  from bot.bot import Bot -from bot.converters import InfractionSearchQuery +from bot.converters import InfractionSearchQuery, allowed_strings  from bot.pagination import LinePaginator  from bot.utils import time  from bot.utils.checks import in_channel_check, with_role_check @@ -23,15 +23,6 @@ log = logging.getLogger(__name__)  UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: -    """Only allow an expiration to be 'permanent' if it is a string.""" -    expires_at = expires_at.lower() -    if expires_at != "permanent": -        raise commands.BadArgument -    else: -        return expires_at - -  class ModManagement(commands.Cog):      """Management of infractions.""" @@ -61,8 +52,8 @@ class ModManagement(commands.Cog):      async def infraction_edit(          self,          ctx: Context, -        infraction_id: int, -        duration: t.Union[utils.Expiry, permanent_duration, None], +        infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], +        duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None],          *,          reason: str = None      ) -> None: @@ -79,21 +70,40 @@ class ModManagement(commands.Cog):          \u2003`M` - minutesā          \u2003`s` - seconds -        Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp -        can be provided for the duration. +        Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction +        authored by the command invoker should be edited. + +        Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 +        timestamp can be provided for the duration.          """          if duration is None and reason is None:              # Unlike UserInputError, the error handler will show a specified message for BadArgument              raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")          # Retrieve the previous infraction for its information. -        old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') +        if isinstance(infraction_id, str): +            params = { +                "actor__id": ctx.author.id, +                "ordering": "-inserted_at" +            } +            infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + +            if infractions: +                old_infraction = infractions[0] +                infraction_id = old_infraction["id"] +            else: +                await ctx.send( +                    f":x: Couldn't find most recent infraction; you have never given an infraction." +                ) +                return +        else: +            old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}")          request_data = {}          confirm_messages = []          log_text = "" -        if duration == "permanent": +        if isinstance(duration, str):              request_data['expires_at'] = None              confirm_messages.append("marked as permanent")          elif duration is not None: @@ -130,7 +140,8 @@ class ModManagement(commands.Cog):                  New expiry: {new_infraction['expires_at'] or "Permanent"}              """.rstrip() -        await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") +        changes = ' & '.join(confirm_messages) +        await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}")          # Get information about the infraction's user          user_id = new_infraction['user'] @@ -233,6 +244,12 @@ class ModManagement(commands.Cog):          user_id = infraction["user"]          hidden = infraction["hidden"]          created = time.format_infraction(infraction["inserted_at"]) + +        if active: +            remaining = time.until_expiration(infraction["expires_at"]) or "Expired" +        else: +            remaining = "Inactive" +          if infraction["expires_at"] is None:              expires = "*Permanent*"          else: @@ -248,6 +265,7 @@ class ModManagement(commands.Cog):              Reason: {infraction["reason"] or "*None*"}              Created: {created}              Expires: {expires} +            Remaining: {remaining}              Actor: {actor.mention if actor else actor_id}              ID: `{infraction["id"]}`              {"**===============**" if active else "==============="} diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5d6618338..82c01ae96 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -53,39 +53,60 @@ class TokenRemover(Cog):          See: https://discordapp.com/developers/docs/reference#snowflakes          """ +        if self.is_token_in_message(msg): +            await self.take_action(msg) + +    @Cog.listener() +    async def on_message_edit(self, before: Message, after: Message) -> None: +        """ +        Check each edit for a string that matches Discord's token pattern. + +        See: https://discordapp.com/developers/docs/reference#snowflakes +        """ +        if self.is_token_in_message(after): +            await self.take_action(after) + +    async def take_action(self, msg: Message) -> None: +        """Remove the `msg` containing a token an send a mod_log message.""" +        user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') +        self.mod_log.ignore(Event.message_delete, msg.id) +        await msg.delete() +        await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + +        message = ( +            "Censored a seemingly valid token sent by " +            f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " +            f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" +        ) +        log.debug(message) + +        # Send pretty mod log embed to mod-alerts +        await self.mod_log.send_log_message( +            icon_url=Icons.token_removed, +            colour=Colour(Colours.soft_red), +            title="Token removed!", +            text=message, +            thumbnail=msg.author.avatar_url_as(static_format="png"), +            channel_id=Channels.mod_alerts, +        ) + +    @classmethod +    def is_token_in_message(cls, msg: Message) -> bool: +        """Check if `msg` contains a seemly valid token."""          if msg.author.bot: -            return +            return False          maybe_match = TOKEN_RE.search(msg.content)          if maybe_match is None: -            return +            return False          try:              user_id, creation_timestamp, hmac = maybe_match.group(0).split('.')          except ValueError: -            return - -        if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): -            self.mod_log.ignore(Event.message_delete, msg.id) -            await msg.delete() -            await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - -            message = ( -                "Censored a seemingly valid token sent by " -                f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " -                f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" -            ) -            log.debug(message) - -            # Send pretty mod log embed to mod-alerts -            await self.mod_log.send_log_message( -                icon_url=Icons.token_removed, -                colour=Colour(Colours.soft_red), -                title="Token removed!", -                text=message, -                thumbnail=msg.author.avatar_url_as(static_format="png"), -                channel_id=Channels.mod_alerts, -            ) +            return False + +        if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): +            return True      @staticmethod      def is_valid_user_id(b64_content: str) -> bool: diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@  import logging  import re +import typing as t  from datetime import datetime  from ssl import CertificateError -from typing import Union  import dateutil.parser  import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter  log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: +    """ +    Return a converter which only allows arguments equal to one of the given values. + +    Unless preserve_case is True, the argument is converted to lowercase. All values are then +    expected to have already been given in lowercase too. +    """ +    def converter(arg: str) -> str: +        if not preserve_case: +            arg = arg.lower() + +        if arg not in values: +            raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") +        else: +            return arg + +    return converter + +  class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter):      """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""      @staticmethod -    async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: +    async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]:          """Check if the argument is a Discord user, and if not, falls back to a string."""          try:              maybe_snowflake = arg.strip("<@!>") diff --git a/bot/utils/time.py b/bot/utils/time.py index a024674ac..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -113,7 +113,11 @@ def format_infraction(timestamp: str) -> str:      return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) -def format_infraction_with_duration(expiry: str, date_from: datetime.datetime = None, max_units: int = 2) -> str: +def format_infraction_with_duration( +    expiry: Optional[str], +    date_from: Optional[datetime.datetime] = None, +    max_units: int = 2 +) -> Optional[str]:      """      Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. @@ -134,3 +138,28 @@ def format_infraction_with_duration(expiry: str, date_from: datetime.datetime =      duration_formatted = f" ({duration})" if duration else ''      return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration( +    expiry: Optional[str], +    now: Optional[datetime.datetime] = None, +    max_units: int = 2 +) -> Optional[str]: +    """ +    Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + +    Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. +    Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. +    `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). +    By default, max_units is 2. +    """ +    if not expiry: +        return None + +    now = now or datetime.datetime.utcnow() +    since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + +    if since < now: +        return None + +    return humanize_delta(relativedelta(since, now), max_units=max_units) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..69f35f2f5 --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,162 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): +    """Test helper functions in bot.utils.time.""" + +    def test_humanize_delta_handle_unknown_units(self): +        """humanize_delta should be able to handle unknown units, and will not abort.""" +        # Does not abort for unknown units, as the unit name is checked +        # against the attribute of the relativedelta instance. +        self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + +    def test_humanize_delta_handle_high_units(self): +        """humanize_delta should be able to handle very high units.""" +        # Very high maximum units, but it only ever iterates over +        # each value the relativedelta might have. +        self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + +    def test_humanize_delta_should_normal_usage(self): +        """Testing humanize delta.""" +        test_cases = ( +            (relativedelta(days=2), 'seconds', 1, '2 days'), +            (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), +            (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), +            (relativedelta(days=2, hours=2), 'days', 2, '2 days'), +        ) + +        for delta, precision, max_units, expected in test_cases: +            with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): +                self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + +    def test_humanize_delta_raises_for_invalid_max_units(self): +        """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" +        test_cases = (-1, 0) + +        for max_units in test_cases: +            with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: +                time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) +                self.assertEqual(str(error), 'max_units must be positive') + +    def test_parse_rfc1123(self): +        """Testing parse_rfc1123.""" +        self.assertEqual( +            time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), +            datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) +        ) + +    def test_format_infraction(self): +        """Testing format_infraction.""" +        self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') + +    @patch('asyncio.sleep', new_callable=AsyncMock) +    def test_wait_until(self, mock): +        """Testing wait_until.""" +        start = datetime(2019, 1, 1, 0, 0) +        then = datetime(2019, 1, 1, 0, 10) + +        # No return value +        self.assertIs(asyncio.run(time.wait_until(then, start)), None) + +        mock.assert_called_once_with(10 * 60) + +    def test_format_infraction_with_duration_none_expiry(self): +        """format_infraction_with_duration should work for None expiry.""" +        test_cases = ( +            (None, None, None, None), + +            # To make sure that date_from and max_units are not touched +            (None, 'Why hello there!', None, None), +            (None, None, float('inf'), None), +            (None, 'Why hello there!', float('inf'), None), +        ) + +        for expiry, date_from, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): +                self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + +    def test_format_infraction_with_duration_custom_units(self): +        """format_infraction_with_duration should work for custom max_units.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, +             '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, +             '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') +        ) + +        for expiry, date_from, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): +                self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + +    def test_format_infraction_with_duration_normal_usage(self): +        """format_infraction_with_duration should work for normal usage, across various durations.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), +            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), +            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), +            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), +            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), +            ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, +             '2019-11-23 23:59 (9 minutes and 55 seconds)'), +            (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), +        ) + +        for expiry, date_from, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): +                self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + +    def test_until_expiration_with_duration_none_expiry(self): +        """until_expiration should work for None expiry.""" +        test_cases = ( +            (None, None, None, None), + +            # To make sure that now and max_units are not touched +            (None, 'Why hello there!', None, None), +            (None, None, float('inf'), None), +            (None, 'Why hello there!', float('inf'), None), +        ) + +        for expiry, now, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): +                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + +    def test_until_expiration_with_duration_custom_units(self): +        """until_expiration should work for custom max_units.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') +        ) + +        for expiry, now, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): +                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + +    def test_until_expiration_normal_usage(self): +        """until_expiration should work for normal usage, across various durations.""" +        test_cases = ( +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), +            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), +            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), +            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), +            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), +            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), +            ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), +            (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), +        ) + +        for expiry, now, max_units, expected in test_cases: +            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): +                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) | 
