diff options
author | 2019-09-23 21:13:47 +0200 | |
---|---|---|
committer | 2019-09-23 21:31:19 +0200 | |
commit | fba165037943fda90039ec9cadf0649cfae0e781 (patch) | |
tree | 4469b2b431316ca7e4829fdbc21aeb52844eba0d | |
parent | Make DEFCON days subcommand enable DEFCON (#405) (diff) |
Fix failing duration conversion
https://github.com/python-discord/bot/issues/446
The current ExpirationDate converter does not convert duration strings
to `datetime.datetime` objects correctly. To remedy the problem, I've
written a new Duration converter that uses regex matching to extract
the relevant duration units and `dateutil.relativedelta.relativedelta`
to compute a `datetime.datetime` that's the given duration in the
future.
I've left the old `ExpirationDate` converter in place for now, since
the new Duration converter may not be the most optimal method. However,
given the importance of being able to convert durations for moderation
tasks, I think it's better to implement Duration now and rethink the
approach later.
This commit closes #446
-rw-r--r-- | bot/cogs/antispam.py | 4 | ||||
-rw-r--r-- | bot/cogs/moderation.py | 12 | ||||
-rw-r--r-- | bot/cogs/reminders.py | 8 | ||||
-rw-r--r-- | bot/cogs/superstarify/__init__.py | 4 | ||||
-rw-r--r-- | bot/converters.py | 35 | ||||
-rw-r--r-- | tests/test_converters.py | 82 |
6 files changed, 127 insertions, 18 deletions
diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index 7a3360436..8dfa0ad05 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -17,7 +17,7 @@ from bot.constants import ( Guild as GuildConfig, Icons, STAFF_ROLES, ) -from bot.converters import ExpirationDate +from bot.converters import Duration log = logging.getLogger(__name__) @@ -102,7 +102,7 @@ class AntiSpam(Cog): self.validation_errors = validation_errors role_id = AntiSpamConfig.punishment['role_id'] self.muted_role = Object(role_id) - self.expiration_date_converter = ExpirationDate() + self.expiration_date_converter = Duration() self.message_deletion_queue = dict() self.queue_consumption_tasks = dict() diff --git a/bot/cogs/moderation.py b/bot/cogs/moderation.py index 81b3864a7..4d651bef7 100644 --- a/bot/cogs/moderation.py +++ b/bot/cogs/moderation.py @@ -14,7 +14,7 @@ from discord.ext.commands import ( from bot import constants from bot.cogs.modlog import ModLog from bot.constants import Colours, Event, Icons, MODERATION_ROLES -from bot.converters import ExpirationDate, InfractionSearchQuery +from bot.converters import Duration, InfractionSearchQuery from bot.decorators import with_role from bot.pagination import LinePaginator from bot.utils.moderation import already_has_active_infraction, post_infraction @@ -279,7 +279,7 @@ class Moderation(Scheduler, Cog): @with_role(*MODERATION_ROLES) @command() - async def tempmute(self, ctx: Context, user: Member, duration: ExpirationDate, *, reason: str = None) -> None: + async def tempmute(self, ctx: Context, user: Member, duration: Duration, *, reason: str = None) -> None: """ Create a temporary mute infraction for a user with the provided expiration and reason. @@ -345,7 +345,7 @@ class Moderation(Scheduler, Cog): @with_role(*MODERATION_ROLES) @command() - async def tempban(self, ctx: Context, user: UserTypes, duration: ExpirationDate, *, reason: str = None) -> None: + async def tempban(self, ctx: Context, user: UserTypes, duration: Duration, *, reason: str = None) -> None: """ Create a temporary ban infraction for a user with the provided expiration and reason. @@ -600,7 +600,7 @@ class Moderation(Scheduler, Cog): @with_role(*MODERATION_ROLES) @command(hidden=True, aliases=["shadowtempmute, stempmute"]) async def shadow_tempmute( - self, ctx: Context, user: Member, duration: ExpirationDate, *, reason: str = None + self, ctx: Context, user: Member, duration: Duration, *, reason: str = None ) -> None: """ Create a temporary mute infraction for a user with the provided reason. @@ -653,7 +653,7 @@ class Moderation(Scheduler, Cog): @with_role(*MODERATION_ROLES) @command(hidden=True, aliases=["shadowtempban, stempban"]) async def shadow_tempban( - self, ctx: Context, user: UserTypes, duration: ExpirationDate, *, reason: str = None + self, ctx: Context, user: UserTypes, duration: Duration, *, reason: str = None ) -> None: """ Create a temporary ban infraction for a user with the provided reason. @@ -884,7 +884,7 @@ class Moderation(Scheduler, Cog): @infraction_edit_group.command(name="duration") async def edit_duration( self, ctx: Context, - infraction_id: int, expires_at: Union[ExpirationDate, str] + infraction_id: int, expires_at: Union[Duration, str] ) -> None: """ Sets the duration of the given infraction, relative to the time of updating. diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 8460de91f..c37abf21e 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -11,7 +11,7 @@ from discord import Colour, Embed, Message from discord.ext.commands import Bot, Cog, Context, group from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES -from bot.converters import ExpirationDate +from bot.converters import Duration from bot.pagination import LinePaginator from bot.utils.checks import without_role_check from bot.utils.scheduling import Scheduler @@ -118,12 +118,12 @@ class Reminders(Scheduler, Cog): await self._delete_reminder(reminder["id"]) @group(name="remind", aliases=("reminder", "reminders"), invoke_without_command=True) - async def remind_group(self, ctx: Context, expiration: ExpirationDate, *, content: str) -> None: + async def remind_group(self, ctx: Context, expiration: Duration, *, content: str) -> None: """Commands for managing your reminders.""" await ctx.invoke(self.new_reminder, expiration=expiration, content=content) @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder(self, ctx: Context, expiration: ExpirationDate, *, content: str) -> Optional[Message]: + async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> Optional[Message]: """ Set yourself a simple reminder. @@ -237,7 +237,7 @@ class Reminders(Scheduler, Cog): await ctx.invoke(self.bot.get_command("help"), "reminders", "edit") @edit_reminder_group.command(name="duration", aliases=("time",)) - async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: ExpirationDate) -> None: + async def edit_reminder_duration(self, ctx: Context, id_: int, expiration: Duration) -> None: """ Edit one of your reminder's expiration. diff --git a/bot/cogs/superstarify/__init__.py b/bot/cogs/superstarify/__init__.py index f7d6a269d..b1936ef3a 100644 --- a/bot/cogs/superstarify/__init__.py +++ b/bot/cogs/superstarify/__init__.py @@ -10,7 +10,7 @@ from bot.cogs.moderation import Moderation from bot.cogs.modlog import ModLog from bot.cogs.superstarify.stars import get_nick from bot.constants import Icons, MODERATION_ROLES, POSITIVE_REPLIES -from bot.converters import ExpirationDate +from bot.converters import Duration from bot.decorators import with_role from bot.utils.moderation import post_infraction @@ -153,7 +153,7 @@ class Superstarify(Cog): @command(name='superstarify', aliases=('force_nick', 'star')) @with_role(*MODERATION_ROLES) async def superstarify( - self, ctx: Context, member: Member, expiration: ExpirationDate, reason: str = None + self, ctx: Context, member: Member, expiration: Duration, reason: str = None ) -> None: """ Force a random superstar name (like Taylor Swift) to be the user's nickname for a specified duration. diff --git a/bot/converters.py b/bot/converters.py index 7386187ab..b7340982b 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,4 +1,5 @@ import logging +import re from datetime import datetime from ssl import CertificateError from typing import Union @@ -6,6 +7,7 @@ from typing import Union import dateparser import discord from aiohttp import ClientConnectorError +from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Context, Converter @@ -197,3 +199,36 @@ class ExpirationDate(Converter): expiry = now + (now - expiry) return expiry + + +class Duration(Converter): + """Convert duration strings into UTC datetime.datetime objects.""" + + duration_parser = re.compile( + r"((?P<years>\d+?)(years|year|Y|y))?" + r"((?P<months>\d+?)(months|month|m))?" + r"((?P<weeks>\d+?)(weeks|week|W|w))?" + r"((?P<days>\d+?)(days|day|D|d))?" + r"((?P<hours>\d+?)(hours|hour|H|h))?" + r"((?P<minutes>\d+?)(minutes|minute|M))?" + r"((?P<seconds>\d+?)(seconds|second|S|s))?" + ) + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the future. + + The converter supports years (symbols: `years`, `year, `Y`, `y`), months (`months`, `month`, + `m`), weeks (`weeks`, `week`, `W`, `w`), days (`days`, `day`, `D`, `d`), hours (`hours`, + `hour`, `H`, `h`), minutes (`minutes`, `minute`, `M`), and seconds (`seconds`, `second`, + `S`, `s`), The units must be provided in descending order of magnitude. + """ + match = self.duration_parser.fullmatch(duration) + if not match: + raise BadArgument(f"`{duration}` is not a valid duration string.") + + duration_dict = {unit: int(amount) for unit, amount in match.groupdict().items() if amount} + delta = relativedelta(**duration_dict) + now = datetime.utcnow() + + return now + delta diff --git a/tests/test_converters.py b/tests/test_converters.py index 3cf774c80..3cf00035f 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -1,11 +1,12 @@ import asyncio -from datetime import datetime -from unittest.mock import MagicMock +import datetime +from unittest.mock import MagicMock, patch import pytest from discord.ext.commands import BadArgument from bot.converters import ( + Duration, ExpirationDate, TagContentConverter, TagNameConverter, @@ -17,10 +18,10 @@ from bot.converters import ( ('value', 'expected'), ( # sorry aliens - ('2199-01-01T00:00:00', datetime(2199, 1, 1)), + ('2199-01-01T00:00:00', datetime.datetime(2199, 1, 1)), ) ) -def test_expiration_date_converter_for_valid(value: str, expected: datetime): +def test_expiration_date_converter_for_valid(value: str, expected: datetime.datetime): converter = ExpirationDate() assert asyncio.run(converter.convert(None, value)) == expected @@ -91,3 +92,76 @@ def test_valid_python_identifier_for_valid(value: str): def test_valid_python_identifier_for_invalid(value: str): with pytest.raises(BadArgument, match=f'`{value}` is not a valid Python identifier'): asyncio.run(ValidPythonIdentifier.convert(None, value)) + + +FIXED_UTC_NOW = datetime.datetime.fromisoformat('2019-01-01T00:00:00') + + + ('duration', 'expected'), + ( + # Simple duration strings + ('1Y', datetime.datetime.fromisoformat('2020-01-01T00:00:00')), + ('1y', datetime.datetime.fromisoformat('2020-01-01T00:00:00')), + ('1year', datetime.datetime.fromisoformat('2020-01-01T00:00:00')), + ('1years', datetime.datetime.fromisoformat('2020-01-01T00:00:00')), + ('1m', datetime.datetime.fromisoformat('2019-02-01T00:00:00')), + ('1month', datetime.datetime.fromisoformat('2019-02-01T00:00:00')), + ('1months', datetime.datetime.fromisoformat('2019-02-01T00:00:00')), + ('1w', datetime.datetime.fromisoformat('2019-01-08T00:00:00')), + ('1W', datetime.datetime.fromisoformat('2019-01-08T00:00:00')), + ('1week', datetime.datetime.fromisoformat('2019-01-08T00:00:00')), + ('1weeks', datetime.datetime.fromisoformat('2019-01-08T00:00:00')), + ('1d', datetime.datetime.fromisoformat('2019-01-02T00:00:00')), + ('1D', datetime.datetime.fromisoformat('2019-01-02T00:00:00')), + ('1day', datetime.datetime.fromisoformat('2019-01-02T00:00:00')), + ('1days', datetime.datetime.fromisoformat('2019-01-02T00:00:00')), + ('1h', datetime.datetime.fromisoformat('2019-01-01T01:00:00')), + ('1H', datetime.datetime.fromisoformat('2019-01-01T01:00:00')), + ('1hour', datetime.datetime.fromisoformat('2019-01-01T01:00:00')), + ('1hours', datetime.datetime.fromisoformat('2019-01-01T01:00:00')), + ('1M', datetime.datetime.fromisoformat('2019-01-01T00:01:00')), + ('1minute', datetime.datetime.fromisoformat('2019-01-01T00:01:00')), + ('1minutes', datetime.datetime.fromisoformat('2019-01-01T00:01:00')), + ('1s', datetime.datetime.fromisoformat('2019-01-01T00:00:01')), + ('1S', datetime.datetime.fromisoformat('2019-01-01T00:00:01')), + ('1second', datetime.datetime.fromisoformat('2019-01-01T00:00:01')), + ('1seconds', datetime.datetime.fromisoformat('2019-01-01T00:00:01')), + + # Complex duration strings + ('1y1m1w1d1H1M1S', datetime.datetime.fromisoformat('2020-02-09T01:01:01')), + ('5y100S', datetime.datetime.fromisoformat('2024-01-01T00:01:40')), + ('2w28H', datetime.datetime.fromisoformat('2019-01-16T04:00:00')), + ) +) +def test_duration_converter_for_valid(duration: str, expected: datetime): + converter = Duration() + + with patch('bot.converters.datetime') as mock_datetime: + mock_datetime.utcnow.return_value = FIXED_UTC_NOW + assert asyncio.run(converter.convert(None, duration)) == expected + + + ('duration'), + ( + # Units in wrong order + ('1d1w'), + ('1s1y'), + + # Unknown substrings + ('1MVes'), + ('1y3breads'), + + # Missing amount + ('ym'), + + # Garbage + ('Guido van Rossum'), + ('lemon lemon lemon lemon lemon lemon lemon'), + ) +) +def test_duration_converter_for_invalid(duration: str): + converter = Duration() + with pytest.raises(BadArgument, match=f'`{duration}` is not a valid duration string.'): + asyncio.run(converter.convert(None, duration)) |