diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 270 |
1 files changed, 127 insertions, 143 deletions
diff --git a/bot/converters.py b/bot/converters.py index 1100b502c..cf0496541 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,134 +1,34 @@ import logging -import random -import socket +import re from datetime import datetime from ssl import CertificateError +from typing import Union -import dateparser +import dateutil.parser +import dateutil.tz import discord -from aiohttp import AsyncResolver, ClientConnectorError, ClientSession, TCPConnector +from aiohttp import ClientConnectorError +from dateutil.relativedelta import relativedelta from discord.ext.commands import BadArgument, Context, Converter -from fuzzywuzzy import fuzz - -from bot.constants import DEBUG_MODE, Keys, URLs -from bot.utils import disambiguate log = logging.getLogger(__name__) -class Snake(Converter): - snakes = None - special_cases = None - - async def convert(self, ctx, name): - await self.build_list() - name = name.lower() - - if name == 'python': - return 'Python (programming language)' - - def get_potential(iterable, *, threshold=80): - nonlocal name - potential = [] - - for item in iterable: - original, item = item, item.lower() - - if name == item: - return [original] - - a, b = fuzz.ratio(name, item), fuzz.partial_ratio(name, item) - if a >= threshold or b >= threshold: - potential.append(original) - - return potential - - # Handle special cases - if name.lower() in self.special_cases: - return self.special_cases.get(name.lower(), name.lower()) - - names = {snake['name']: snake['scientific'] for snake in self.snakes} - all_names = names.keys() | names.values() - timeout = len(all_names) * (3 / 4) - - embed = discord.Embed(title='Found multiple choices. Please choose the correct one.', colour=0x59982F) - embed.set_author(name=ctx.author.display_name, icon_url=ctx.author.avatar_url) - - name = await disambiguate(ctx, get_potential(all_names), timeout=timeout, embed=embed) - return names.get(name, name) - - @classmethod - async def build_list(cls): - - headers = {"X-API-KEY": Keys.site_api} - - # Set up the session - if DEBUG_MODE: - http_session = ClientSession( - connector=TCPConnector( - resolver=AsyncResolver(), - family=socket.AF_INET, - verify_ssl=False, - ) - ) - else: - http_session = ClientSession( - connector=TCPConnector( - resolver=AsyncResolver() - ) - ) - - # Get all the snakes - if cls.snakes is None: - response = await http_session.get( - URLs.site_names_api, - params={"get_all": "true"}, - headers=headers - ) - cls.snakes = await response.json() - - # Get the special cases - if cls.special_cases is None: - response = await http_session.get( - URLs.site_special_api, - headers=headers - ) - special_cases = await response.json() - cls.special_cases = {snake['name'].lower(): snake for snake in special_cases} - - # Close the session - http_session.close() - - @classmethod - async def random(cls): - """ - This is stupid. We should find a way to - somehow get the global session into a - global context, so I can get it from here. - :return: - """ - - await cls.build_list() - names = [snake['scientific'] for snake in cls.snakes] - return random.choice(names) - - class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. - This is used to have package names - that correspond to how you would use - the package in your code, e.g. - `import package`. Raises `BadArgument` - if the argument is not a valid Python - identifier, and simply passes through + This is used to have package names that correspond to how you would use the package in your + code, e.g. `import package`. + + Raises `BadArgument` if the argument is not a valid Python identifier, and simply passes through the given argument otherwise. """ @staticmethod - async def convert(ctx, argument: str): + async def convert(ctx: Context, argument: str) -> str: + """Checks whether the given string is a valid Python identifier.""" if not argument.isidentifier(): raise BadArgument(f"`{argument}` is not a valid Python identifier") return argument @@ -138,19 +38,20 @@ class ValidURL(Converter): """ Represents a valid webpage URL. - This converter checks whether the given - URL can be reached and requesting it returns - a status code of 200. If not, `BadArgument` - is raised. Otherwise, it simply passes through the given URL. + This converter checks whether the given URL can be reached and requesting it returns a status + code of 200. If not, `BadArgument` is raised. + + Otherwise, it simply passes through the given URL. """ @staticmethod - async def convert(ctx, url: str): + async def convert(ctx: Context, url: str) -> str: + """This converter checks whether the given URL can be reached with a status code of 200.""" try: async with ctx.bot.http_session.get(url) as resp: if resp.status != 200: raise BadArgument( - f"HTTP GET on `{url}` returned status `{resp.status_code}`, expected 200" + f"HTTP GET on `{url}` returned status `{resp.status}`, expected 200" ) except CertificateError: if url.startswith('https'): @@ -166,26 +67,28 @@ class ValidURL(Converter): class InfractionSearchQuery(Converter): - """ - A converter that checks if the argument is a Discord user, and if not, falls back to a string. - """ + """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx, arg): + async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") - return await ctx.bot.get_user_info(maybe_snowflake) + return await ctx.bot.fetch_user(maybe_snowflake) except (discord.NotFound, discord.HTTPException): return arg class Subreddit(Converter): - """ - Forces a string to begin with "r/" and checks if it's a valid subreddit. - """ + """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" @staticmethod - async def convert(ctx, sub: str): + async def convert(ctx: Context, sub: str) -> str: + """ + Force sub to begin with "r/" and check if it's a valid subreddit. + + If sub is a valid subreddit, return it prepended with "r/" + """ sub = sub.lower() if not sub.startswith("r/"): @@ -206,9 +109,21 @@ class Subreddit(Converter): class TagNameConverter(Converter): + """ + Ensure that a proposed tag name is valid. + + Valid tag names meet the following conditions: + * All ASCII characters + * Has at least one non-whitespace character + * Not solely numeric + * Shorter than 127 characters + """ + @staticmethod - async def convert(ctx: Context, tag_name: str): - def is_number(value): + async def convert(ctx: Context, tag_name: str) -> str: + """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" + def is_number(value: str) -> bool: + """Check to see if the input string is numeric.""" try: float(value) except ValueError: @@ -245,8 +160,15 @@ class TagNameConverter(Converter): class TagContentConverter(Converter): + """Ensure proposed tag content is not empty and contains at least one non-whitespace character.""" + @staticmethod - async def convert(ctx: Context, tag_content: str): + async def convert(ctx: Context, tag_content: str) -> str: + """ + Ensure tag_content is non-empty and contains at least one non-whitespace character. + + If tag_content is valid, return the stripped version. + """ tag_content = tag_content.strip() # The tag contents should not be empty, or filled with whitespace. @@ -258,20 +180,82 @@ class TagContentConverter(Converter): return tag_content -class ExpirationDate(Converter): - DATEPARSER_SETTINGS = { - 'PREFER_DATES_FROM': 'future', - 'TIMEZONE': 'UTC', - 'TO_TIMEZONE': 'UTC' - } +class Duration(Converter): + """Convert duration strings into UTC datetime.datetime objects.""" - async def convert(self, ctx, expiration_string: str): - expiry = dateparser.parse(expiration_string, settings=self.DATEPARSER_SETTINGS) - if expiry is None: - raise BadArgument(f"Failed to parse expiration date from `{expiration_string}`") + duration_parser = re.compile( + r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" + r"((?P<months>\d+?) ?(months|month|m) ?)?" + r"((?P<weeks>\d+?) ?(weeks|week|W|w) ?)?" + r"((?P<days>\d+?) ?(days|day|D|d) ?)?" + r"((?P<hours>\d+?) ?(hours|hour|H|h) ?)?" + r"((?P<minutes>\d+?) ?(minutes|minute|M) ?)?" + r"((?P<seconds>\d+?) ?(seconds|second|S|s))?" + ) + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the future. + + The converter supports the following symbols for each unit of time: + - years: `Y`, `y`, `year`, `years` + - months: `m`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `M`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + The units need to be provided in descending order of magnitude. + """ + match = self.duration_parser.fullmatch(duration) + if not match: + raise BadArgument(f"`{duration}` is not a valid duration string.") + + duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()} + delta = relativedelta(**duration_dict) now = datetime.utcnow() - if expiry < now: - expiry = now + (now - expiry) - return expiry + return now + delta + + +class ISODateTime(Converter): + """Converts an ISO-8601 datetime string into a datetime.datetime.""" + + async def convert(self, ctx: Context, datetime_string: str) -> datetime: + """ + Converts a ISO-8601 `datetime_string` into a `datetime.datetime` object. + + The converter is flexible in the formats it accepts, as it uses the `isoparse` method of + `dateutil.parser`. In general, it accepts datetime strings that start with a date, + optionally followed by a time. Specifying a timezone offset in the datetime string is + supported, but the `datetime` object will be converted to UTC and will be returned without + `tzinfo` as a timezone-unaware `datetime` object. + + See: https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.isoparse + + Formats that are guaranteed to be valid by our tests are: + + - `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` + - `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM` + - `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM` + - `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH` + - `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS` + - `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM` + - `YYYY-mm-dd` + - `YYYY-mm` + - `YYYY` + + Note: ISO-8601 specifies a `T` as the separator between the date and the time part of the + datetime string. The converter accepts both a `T` and a single space character. + """ + try: + dt = dateutil.parser.isoparse(datetime_string) + except ValueError: + raise BadArgument(f"`{datetime_string}` is not a valid ISO-8601 datetime string") + + if dt.tzinfo: + dt = dt.astimezone(dateutil.tz.UTC) + dt = dt.replace(tzinfo=None) + + return dt |