diff options
Diffstat (limited to 'bot/converters.py')
-rw-r--r-- | bot/converters.py | 86 |
1 files changed, 52 insertions, 34 deletions
diff --git a/bot/converters.py b/bot/converters.py index 0d9a519df..3bf05cfb3 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -15,7 +15,9 @@ from discord.utils import DISCORD_EPOCH, snowflake_time from bot.api import ResponseCodeError from bot.constants import URLs +from bot.exts.info.doc import _inventory_parser from bot.utils.regex import INVITE_RE +from bot.utils.time import parse_duration_string log = logging.getLogger(__name__) @@ -126,22 +128,20 @@ class ValidFilterListType(Converter): return list_type -class ValidPythonIdentifier(Converter): +class PackageName(Converter): """ - A converter that checks whether the given string is a valid Python identifier. + A converter that checks whether the given string is a valid package name. - This is used to have package names that correspond to how you would use the package in your - code, e.g. `import package`. - - Raises `BadArgument` if the argument is not a valid Python identifier, and simply passes through - the given argument otherwise. + Package names are used for stats and are restricted to the a-z and _ characters. """ - @staticmethod - async def convert(ctx: Context, argument: str) -> str: - """Checks whether the given string is a valid Python identifier.""" - if not argument.isidentifier(): - raise BadArgument(f"`{argument}` is not a valid Python identifier") + PACKAGE_NAME_RE = re.compile(r"[^a-z0-9_]") + + @classmethod + async def convert(cls, ctx: Context, argument: str) -> str: + """Checks whether the given string is a valid package name.""" + if cls.PACKAGE_NAME_RE.search(argument): + raise BadArgument("The provided package name is not valid; please only use the _, 0-9, and a-z characters.") return argument @@ -177,6 +177,27 @@ class ValidURL(Converter): return url +class Inventory(Converter): + """ + Represents an Intersphinx inventory URL. + + This converter checks whether intersphinx accepts the given inventory URL, and raises + `BadArgument` if that is not the case or if the url is unreachable. + + Otherwise, it returns the url and the fetched inventory dict in a tuple. + """ + + @staticmethod + async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]: + """Convert url to Intersphinx inventory URL.""" + await ctx.trigger_typing() + if (inventory := await _inventory_parser.fetch_inventory(url)) is None: + raise BadArgument( + f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." + ) + return url, inventory + + class Snowflake(IDConverter): """ Converts to an int if the argument is a valid Discord snowflake. @@ -301,16 +322,6 @@ class TagContentConverter(Converter): class DurationDelta(Converter): """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" - duration_parser = re.compile( - r"((?P<years>\d+?) ?(years|year|Y|y) ?)?" - r"((?P<months>\d+?) ?(months|month|m) ?)?" - r"((?P<weeks>\d+?) ?(weeks|week|W|w) ?)?" - r"((?P<days>\d+?) ?(days|day|D|d) ?)?" - r"((?P<hours>\d+?) ?(hours|hour|H|h) ?)?" - r"((?P<minutes>\d+?) ?(minutes|minute|M) ?)?" - r"((?P<seconds>\d+?) ?(seconds|second|S|s))?" - ) - async def convert(self, ctx: Context, duration: str) -> relativedelta: """ Converts a `duration` string to a relativedelta object. @@ -326,13 +337,9 @@ class DurationDelta(Converter): The units need to be provided in descending order of magnitude. """ - match = self.duration_parser.fullmatch(duration) - if not match: + if not (delta := parse_duration_string(duration)): raise BadArgument(f"`{duration}` is not a valid duration string.") - duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default=0).items()} - delta = relativedelta(**duration_dict) - return delta @@ -357,27 +364,38 @@ class Duration(DurationDelta): class OffTopicName(Converter): """A converter that ensures an added off-topic name is valid.""" + ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" + + @classmethod + def translate_name(cls, name: str, *, from_unicode: bool = True) -> str: + """ + Translates `name` into a format that is allowed in discord channel names. + + If `from_unicode` is True, the name is translated from a discord-safe format, back to normalized text. + """ + if from_unicode: + table = str.maketrans(cls.ALLOWED_CHARACTERS, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-') + else: + table = str.maketrans('๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-', cls.ALLOWED_CHARACTERS) + + return name.translate(table) + async def convert(self, ctx: Context, argument: str) -> str: """Attempt to replace any invalid characters with their approximate Unicode equivalent.""" - allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-" - # Chain multiple words to a single one argument = "-".join(argument.split()) if not (2 <= len(argument) <= 96): raise BadArgument("Channel name must be between 2 and 96 chars long") - elif not all(c.isalnum() or c in allowed_characters for c in argument): + elif not all(c.isalnum() or c in self.ALLOWED_CHARACTERS for c in argument): raise BadArgument( "Channel name must only consist of " "alphanumeric characters, minus signs or apostrophes." ) # Replace invalid characters with unicode alternatives. - table = str.maketrans( - allowed_characters, '๐ ๐ก๐ข๐ฃ๐ค๐ฅ๐ฆ๐ง๐จ๐ฉ๐ช๐ซ๐ฌ๐ญ๐ฎ๐ฏ๐ฐ๐ฑ๐ฒ๐ณ๐ด๐ต๐ถ๐ท๐ธ๐นว๏ผโโ-' - ) - return argument.translate(table) + return self.translate_name(argument) class ISODateTime(Converter): |