diff options
Diffstat (limited to 'bot/converters.py')
| -rw-r--r-- | bot/converters.py | 121 | 
1 files changed, 82 insertions, 39 deletions
| diff --git a/bot/converters.py b/bot/converters.py index 4bd9aba13..339da7b60 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,10 +1,12 @@  import logging +import re  from datetime import datetime  from ssl import CertificateError +from typing import Union -import dateparser  import discord  from aiohttp import ClientConnectorError +from dateutil.relativedelta import relativedelta  from discord.ext.commands import BadArgument, Context, Converter @@ -15,17 +17,16 @@ class ValidPythonIdentifier(Converter):      """      A converter that checks whether the given string is a valid Python identifier. -    This is used to have package names -    that correspond to how you would use -    the package in your code, e.g. -    `import package`. Raises `BadArgument` -    if the argument is not a valid Python -    identifier, and simply passes through +    This is used to have package names that correspond to how you would use the package in your +    code, e.g. `import package`. + +    Raises `BadArgument` if the argument is not a valid Python identifier, and simply passes through      the given argument otherwise.      """      @staticmethod -    async def convert(ctx, argument: str): +    async def convert(ctx: Context, argument: str) -> str: +        """Checks whether the given string is a valid Python identifier."""          if not argument.isidentifier():              raise BadArgument(f"`{argument}` is not a valid Python identifier")          return argument @@ -35,14 +36,15 @@ class ValidURL(Converter):      """      Represents a valid webpage URL. -    This converter checks whether the given -    URL can be reached and requesting it returns -    a status code of 200. If not, `BadArgument` -    is raised. Otherwise, it simply passes through the given URL. +    This converter checks whether the given URL can be reached and requesting it returns a status +    code of 200. If not, `BadArgument` is raised. + +    Otherwise, it simply passes through the given URL.      """      @staticmethod -    async def convert(ctx, url: str): +    async def convert(ctx: Context, url: str) -> str: +        """This converter checks whether the given URL can be reached with a status code of 200."""          try:              async with ctx.bot.http_session.get(url) as resp:                  if resp.status != 200: @@ -63,12 +65,11 @@ class ValidURL(Converter):  class InfractionSearchQuery(Converter): -    """ -    A converter that checks if the argument is a Discord user, and if not, falls back to a string. -    """ +    """A converter that checks if the argument is a Discord user, and if not, falls back to a string."""      @staticmethod -    async def convert(ctx, arg): +    async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: +        """Check if the argument is a Discord user, and if not, falls back to a string."""          try:              maybe_snowflake = arg.strip("<@!>")              return await ctx.bot.fetch_user(maybe_snowflake) @@ -77,12 +78,15 @@ class InfractionSearchQuery(Converter):  class Subreddit(Converter): -    """ -    Forces a string to begin with "r/" and checks if it's a valid subreddit. -    """ +    """Forces a string to begin with "r/" and checks if it's a valid subreddit."""      @staticmethod -    async def convert(ctx, sub: str): +    async def convert(ctx: Context, sub: str) -> str: +        """ +        Force sub to begin with "r/" and check if it's a valid subreddit. + +        If sub is a valid subreddit, return it prepended with "r/" +        """          sub = sub.lower()          if not sub.startswith("r/"): @@ -103,9 +107,21 @@ class Subreddit(Converter):  class TagNameConverter(Converter): +    """ +    Ensure that a proposed tag name is valid. + +    Valid tag names meet the following conditions: +        * All ASCII characters +        * Has at least one non-whitespace character +        * Not solely numeric +        * Shorter than 127 characters +    """ +      @staticmethod -    async def convert(ctx: Context, tag_name: str): -        def is_number(value): +    async def convert(ctx: Context, tag_name: str) -> str: +        """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" +        def is_number(value: str) -> bool: +            """Check to see if the input string is numeric."""              try:                  float(value)              except ValueError: @@ -142,8 +158,15 @@ class TagNameConverter(Converter):  class TagContentConverter(Converter): +    """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" +      @staticmethod -    async def convert(ctx: Context, tag_content: str): +    async def convert(ctx: Context, tag_content: str) -> str: +        """ +        Ensure tag_content is non-empty and contains at least one non-whitespace character. + +        If tag_content is valid, return the stripped version. +        """          tag_content = tag_content.strip()          # The tag contents should not be empty, or filled with whitespace. @@ -155,20 +178,40 @@ class TagContentConverter(Converter):          return tag_content -class ExpirationDate(Converter): -    DATEPARSER_SETTINGS = { -        'PREFER_DATES_FROM': 'future', -        'TIMEZONE': 'UTC', -        'TO_TIMEZONE': 'UTC' -    } - -    async def convert(self, ctx, expiration_string: str): -        expiry = dateparser.parse(expiration_string, settings=self.DATEPARSER_SETTINGS) -        if expiry is None: -            raise BadArgument(f"Failed to parse expiration date from `{expiration_string}`") - +class Duration(Converter): +    """Convert duration strings into UTC datetime.datetime objects.""" + +    duration_parser = re.compile( +        r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" +        r"((?P<months>\d+?) ?(months|month|m) ?)?" +        r"((?P<weeks>\d+?) ?(weeks|week|W|w) ?)?" +        r"((?P<days>\d+?) ?(days|day|D|d) ?)?" +        r"((?P<hours>\d+?) ?(hours|hour|H|h) ?)?" +        r"((?P<minutes>\d+?) ?(minutes|minute|M) ?)?" +        r"((?P<seconds>\d+?) ?(seconds|second|S|s))?" +    ) + +    async def convert(self, ctx: Context, duration: str) -> datetime: +        """ +        Converts a `duration` string to a datetime object that's `duration` in the future. + +        The converter supports the following symbols for each unit of time: +        - years: `Y`, `y`, `year`, `years` +        - months: `m`, `month`, `months` +        - weeks: `w`, `W`, `week`, `weeks` +        - days: `d`, `D`, `day`, `days` +        - hours: `H`, `h`, `hour`, `hours` +        - minutes: `M`, `minute`, `minutes` +        - seconds: `S`, `s`, `second`, `seconds` + +        The units need to be provided in descending order of magnitude. +        """ +        match = self.duration_parser.fullmatch(duration) +        if not match: +            raise BadArgument(f"`{duration}` is not a valid duration string.") + +        duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()} +        delta = relativedelta(**duration_dict)          now = datetime.utcnow() -        if expiry < now: -            expiry = now + (now - expiry) -        return expiry +        return now + delta | 
