diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 121 |
1 files changed, 82 insertions, 39 deletions
diff --git a/bot/converters.py b/bot/converters.py index 4bd9aba13..339da7b60 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,10 +1,12 @@ import logging +import re from datetime import datetime from ssl import CertificateError +from typing import Union -import dateparser import discord from aiohttp import ClientConnectorError +from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Context, Converter @@ -15,17 +17,16 @@ class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. - This is used to have package names - that correspond to how you would use - the package in your code, e.g. - `import package`. Raises `BadArgument` - if the argument is not a valid Python - identifier, and simply passes through + This is used to have package names that correspond to how you would use the package in your + code, e.g. `import package`. + + Raises `BadArgument` if the argument is not a valid Python identifier, and simply passes through the given argument otherwise. """ @staticmethod - async def convert(ctx, argument: str): + async def convert(ctx: Context, argument: str) -> str: + """Checks whether the given string is a valid Python identifier.""" if not argument.isidentifier(): raise BadArgument(f"`{argument}` is not a valid Python identifier") return argument @@ -35,14 +36,15 @@ class ValidURL(Converter): """ Represents a valid webpage URL. - This converter checks whether the given - URL can be reached and requesting it returns - a status code of 200. If not, `BadArgument` - is raised. Otherwise, it simply passes through the given URL. + This converter checks whether the given URL can be reached and requesting it returns a status + code of 200. If not, `BadArgument` is raised. + + Otherwise, it simply passes through the given URL. """ @staticmethod - async def convert(ctx, url: str): + async def convert(ctx: Context, url: str) -> str: + """This converter checks whether the given URL can be reached with a status code of 200.""" try: async with ctx.bot.http_session.get(url) as resp: if resp.status != 200: @@ -63,12 +65,11 @@ class ValidURL(Converter): class InfractionSearchQuery(Converter): - """ - A converter that checks if the argument is a Discord user, and if not, falls back to a string. - """ + """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx, arg): + async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") return await ctx.bot.fetch_user(maybe_snowflake) @@ -77,12 +78,15 @@ class InfractionSearchQuery(Converter): class Subreddit(Converter): - """ - Forces a string to begin with "r/" and checks if it's a valid subreddit. - """ + """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" @staticmethod - async def convert(ctx, sub: str): + async def convert(ctx: Context, sub: str) -> str: + """ + Force sub to begin with "r/" and check if it's a valid subreddit. + + If sub is a valid subreddit, return it prepended with "r/" + """ sub = sub.lower() if not sub.startswith("r/"): @@ -103,9 +107,21 @@ class Subreddit(Converter): class TagNameConverter(Converter): + """ + Ensure that a proposed tag name is valid. + + Valid tag names meet the following conditions: + * All ASCII characters + * Has at least one non-whitespace character + * Not solely numeric + * Shorter than 127 characters + """ + @staticmethod - async def convert(ctx: Context, tag_name: str): - def is_number(value): + async def convert(ctx: Context, tag_name: str) -> str: + """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" + def is_number(value: str) -> bool: + """Check to see if the input string is numeric.""" try: float(value) except ValueError: @@ -142,8 +158,15 @@ class TagNameConverter(Converter): class TagContentConverter(Converter): + """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" + @staticmethod - async def convert(ctx: Context, tag_content: str): + async def convert(ctx: Context, tag_content: str) -> str: + """ + Ensure tag_content is non-empty and contains at least one non-whitespace character. + + If tag_content is valid, return the stripped version. + """ tag_content = tag_content.strip() # The tag contents should not be empty, or filled with whitespace. @@ -155,20 +178,40 @@ class TagContentConverter(Converter): return tag_content -class ExpirationDate(Converter): - DATEPARSER_SETTINGS = { - 'PREFER_DATES_FROM': 'future', - 'TIMEZONE': 'UTC', - 'TO_TIMEZONE': 'UTC' - } - - async def convert(self, ctx, expiration_string: str): - expiry = dateparser.parse(expiration_string, settings=self.DATEPARSER_SETTINGS) - if expiry is None: - raise BadArgument(f"Failed to parse expiration date from `{expiration_string}`") - +class Duration(Converter): + """Convert duration strings into UTC datetime.datetime objects.""" + + duration_parser = re.compile( + r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" + r"((?P<months>\d+?) ?(months|month|m) ?)?" + r"((?P<weeks>\d+?) ?(weeks|week|W|w) ?)?" + r"((?P<days>\d+?) ?(days|day|D|d) ?)?" + r"((?P<hours>\d+?) ?(hours|hour|H|h) ?)?" + r"((?P<minutes>\d+?) ?(minutes|minute|M) ?)?" + r"((?P<seconds>\d+?) ?(seconds|second|S|s))?" + ) + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the future. + + The converter supports the following symbols for each unit of time: + - years: `Y`, `y`, `year`, `years` + - months: `m`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `M`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + The units need to be provided in descending order of magnitude. + """ + match = self.duration_parser.fullmatch(duration) + if not match: + raise BadArgument(f"`{duration}` is not a valid duration string.") + + duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()} + delta = relativedelta(**duration_dict) now = datetime.utcnow() - if expiry < now: - expiry = now + (now - expiry) - return expiry + return now + delta |