aboutsummaryrefslogtreecommitdiffstats
path: root/bot/converters.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/converters.py')
-rw-r--r--bot/converters.py332
1 files changed, 179 insertions, 153 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 2a3943831..559e759e1 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,8 +1,8 @@
-import logging
+from __future__ import annotations
+
import re
import typing as t
-from datetime import datetime
-from functools import partial
+from datetime import datetime, timezone
from ssl import CertificateError
import dateutil.parser
@@ -10,18 +10,26 @@ import dateutil.tz
import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
-from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter
-from discord.utils import DISCORD_EPOCH, snowflake_time
+from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter
+from discord.utils import escape_markdown, snowflake_time
+from bot import exts
from bot.api import ResponseCodeError
from bot.constants import URLs
+from bot.errors import InvalidInfraction
from bot.exts.info.doc import _inventory_parser
+from bot.exts.info.tags import TagIdentifier
+from bot.log import get_logger
+from bot.utils.extensions import EXTENSIONS, unqualify
from bot.utils.regex import INVITE_RE
from bot.utils.time import parse_duration_string
-log = logging.getLogger(__name__)
+if t.TYPE_CHECKING:
+ from bot.exts.info.source import SourceType
+
+log = get_logger(__name__)
-DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000)
+DISCORD_EPOCH_DT = snowflake_time(0)
RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")
@@ -64,10 +72,10 @@ class ValidDiscordServerInvite(Converter):
async def convert(self, ctx: Context, server_invite: str) -> dict:
"""Check whether the string is a valid Discord server invite."""
- invite_code = INVITE_RE.search(server_invite)
+ invite_code = INVITE_RE.match(server_invite)
if invite_code:
response = await ctx.bot.http_session.get(
- f"{URLs.discord_invite_api}/{invite_code[1]}"
+ f"{URLs.discord_invite_api}/{invite_code.group('invite')}"
)
if response.status != 404:
invite_data = await response.json()
@@ -128,6 +136,44 @@ class ValidFilterListType(Converter):
return list_type
+class Extension(Converter):
+ """
+ Fully qualify the name of an extension and ensure it exists.
+
+ The * and ** values bypass this when used with the reload command.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Fully qualify the name of an extension and ensure it exists."""
+ # Special values to reload all extensions
+ if argument == "*" or argument == "**":
+ return argument
+
+ argument = argument.lower()
+
+ if argument in EXTENSIONS:
+ return argument
+ elif (qualified_arg := f"{exts.__name__}.{argument}") in EXTENSIONS:
+ return qualified_arg
+
+ matches = []
+ for ext in EXTENSIONS:
+ if argument == unqualify(ext):
+ matches.append(ext)
+
+ if len(matches) > 1:
+ matches.sort()
+ names = "\n".join(matches)
+ raise BadArgument(
+ f":x: `{argument}` is an ambiguous extension name. "
+ f"Please use one of the following fully-qualified names.```\n{names}```"
+ )
+ elif matches:
+ return matches[0]
+ else:
+ raise BadArgument(f":x: Could not find the extension `{argument}`.")
+
+
class PackageName(Converter):
"""
A converter that checks whether the given string is a valid package name.
@@ -191,11 +237,16 @@ class Inventory(Converter):
async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]:
"""Convert url to Intersphinx inventory URL."""
await ctx.trigger_typing()
- if (inventory := await _inventory_parser.fetch_inventory(url)) is None:
- raise BadArgument(
- f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts."
- )
- return url, inventory
+ try:
+ inventory = await _inventory_parser.fetch_inventory(url)
+ except _inventory_parser.InvalidHeaderError:
+ raise BadArgument("Unable to parse inventory because of invalid header, check if URL is correct.")
+ else:
+ if inventory is None:
+ raise BadArgument(
+ f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts."
+ )
+ return url, inventory
class Snowflake(IDConverter):
@@ -230,64 +281,43 @@ class Snowflake(IDConverter):
if time < DISCORD_EPOCH_DT:
raise BadArgument(f"{error}: timestamp is before the Discord epoch.")
- elif (datetime.utcnow() - time).days < -1:
+ elif (datetime.now(timezone.utc) - time).days < -1:
raise BadArgument(f"{error}: timestamp is too far into the future.")
return snowflake
-class TagNameConverter(Converter):
- """
- Ensure that a proposed tag name is valid.
-
- Valid tag names meet the following conditions:
- * All ASCII characters
- * Has at least one non-whitespace character
- * Not solely numeric
- * Shorter than 127 characters
- """
+class SourceConverter(Converter):
+ """Convert an argument into a help command, tag, command, or cog."""
@staticmethod
- async def convert(ctx: Context, tag_name: str) -> str:
- """Lowercase & strip whitespace from proposed tag_name & ensure it's valid."""
- tag_name = tag_name.lower().strip()
-
- # The tag name has at least one invalid character.
- if ascii(tag_name)[1:-1] != tag_name:
- raise BadArgument("Don't be ridiculous, you can't use that character!")
-
- # The tag name is either empty, or consists of nothing but whitespace.
- elif not tag_name:
- raise BadArgument("Tag names should not be empty, or filled with whitespace.")
-
- # The tag name is longer than 127 characters.
- elif len(tag_name) > 127:
- raise BadArgument("Are you insane? That's way too long!")
-
- # The tag name is ascii but does not contain any letters.
- elif not any(character.isalpha() for character in tag_name):
- raise BadArgument("Tag names must contain at least one letter.")
+ async def convert(ctx: Context, argument: str) -> SourceType:
+ """Convert argument into source object."""
+ if argument.lower() == "help":
+ return ctx.bot.help_command
- return tag_name
+ cog = ctx.bot.get_cog(argument)
+ if cog:
+ return cog
+ cmd = ctx.bot.get_command(argument)
+ if cmd:
+ return cmd
-class TagContentConverter(Converter):
- """Ensure proposed tag content is not empty and contains at least one non-whitespace character."""
+ tags_cog = ctx.bot.get_cog("Tags")
+ show_tag = True
- @staticmethod
- async def convert(ctx: Context, tag_content: str) -> str:
- """
- Ensure tag_content is non-empty and contains at least one non-whitespace character.
-
- If tag_content is valid, return the stripped version.
- """
- tag_content = tag_content.strip()
-
- # The tag contents should not be empty, or filled with whitespace.
- if not tag_content:
- raise BadArgument("Tag contents should not be empty, or filled with whitespace.")
+ if not tags_cog:
+ show_tag = False
+ else:
+ identifier = TagIdentifier.from_string(argument.lower())
+ if identifier in tags_cog.tags:
+ return identifier
+ escaped_arg = escape_markdown(argument)
- return tag_content
+ raise BadArgument(
+ f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
+ )
class DurationDelta(Converter):
@@ -324,7 +354,7 @@ class Duration(DurationDelta):
The converter supports the same symbols for each unit of time as its parent class.
"""
delta = await super().convert(ctx, duration)
- now = datetime.utcnow()
+ now = datetime.now(timezone.utc)
try:
return now + delta
@@ -332,10 +362,29 @@ class Duration(DurationDelta):
raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
+class Age(DurationDelta):
+ """Convert duration strings into UTC datetime.datetime objects."""
+
+ async def convert(self, ctx: Context, duration: str) -> datetime:
+ """
+ Converts a `duration` string to a datetime object that's `duration` in the past.
+
+ The converter supports the same symbols for each unit of time as its parent class.
+ """
+ delta = await super().convert(ctx, duration)
+ now = datetime.now(timezone.utc)
+
+ try:
+ return now - delta
+ except (ValueError, OverflowError):
+ raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
+
+
class OffTopicName(Converter):
"""A converter that ensures an added off-topic name is valid."""
- ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-"
+ ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>"
+ TRANSLATED_CHARACTERS = "๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-๏ผœ๏ผž"
@classmethod
def translate_name(cls, name: str, *, from_unicode: bool = True) -> str:
@@ -345,9 +394,9 @@ class OffTopicName(Converter):
If `from_unicode` is True, the name is translated from a discord-safe format, back to normalized text.
"""
if from_unicode:
- table = str.maketrans(cls.ALLOWED_CHARACTERS, '๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-')
+ table = str.maketrans(cls.ALLOWED_CHARACTERS, cls.TRANSLATED_CHARACTERS)
else:
- table = str.maketrans('๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-', cls.ALLOWED_CHARACTERS)
+ table = str.maketrans(cls.TRANSLATED_CHARACTERS, cls.ALLOWED_CHARACTERS)
return name.translate(table)
@@ -379,8 +428,8 @@ class ISODateTime(Converter):
The converter is flexible in the formats it accepts, as it uses the `isoparse` method of
`dateutil.parser`. In general, it accepts datetime strings that start with a date,
optionally followed by a time. Specifying a timezone offset in the datetime string is
- supported, but the `datetime` object will be converted to UTC and will be returned without
- `tzinfo` as a timezone-unaware `datetime` object.
+ supported, but the `datetime` object will be converted to UTC. If no timezone is specified, the datetime will
+ be assumed to be in UTC already. In all cases, the returned object will have the UTC timezone.
See: https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.isoparse
@@ -406,7 +455,8 @@ class ISODateTime(Converter):
if dt.tzinfo:
dt = dt.astimezone(dateutil.tz.UTC)
- dt = dt.replace(tzinfo=None)
+ else: # Without a timezone, assume it represents UTC.
+ dt = dt.replace(tzinfo=dateutil.tz.UTC)
return dt
@@ -416,11 +466,11 @@ class HushDurationConverter(Converter):
MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)")
- async def convert(self, ctx: Context, argument: str) -> t.Optional[int]:
+ async def convert(self, ctx: Context, argument: str) -> int:
"""
Convert `argument` to a duration that's max 15 minutes or None.
- If `"forever"` is passed, None is returned; otherwise an int of the extracted time.
+ If `"forever"` is passed, -1 is returned; otherwise an int of the extracted time.
Accepted formats are:
* <duration>,
* <duration>m,
@@ -428,7 +478,7 @@ class HushDurationConverter(Converter):
* forever.
"""
if argument == "forever":
- return None
+ return -1
match = self.MINUTES_RE.match(argument)
if not match:
raise BadArgument(f"{argument} is not a valid minutes duration.")
@@ -439,103 +489,51 @@ class HushDurationConverter(Converter):
return duration
-def proxy_user(user_id: str) -> discord.Object:
- """
- Create a proxy user object from the given id.
+def _is_an_unambiguous_user_argument(argument: str) -> bool:
+ """Check if the provided argument is a user mention, user id, or username (name#discrim)."""
+ has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument))
- Used when a Member or User object cannot be resolved.
- """
- log.trace(f"Attempting to create a proxy user for the user id {user_id}.")
+ # Check to see if the author passed a username (a discriminator exists)
+ argument = argument.removeprefix('@')
+ has_username = len(argument) > 5 and argument[-5] == '#'
- try:
- user_id = int(user_id)
- except ValueError:
- log.debug(f"Failed to create proxy user {user_id}: could not convert to int.")
- raise BadArgument(f"User ID `{user_id}` is invalid - could not convert to an integer.")
+ return has_id_or_mention or has_username
- user = discord.Object(user_id)
- user.mention = user.id
- user.display_name = f"<@{user.id}>"
- user.avatar_url_as = lambda static_format: None
- user.bot = False
- return user
+AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format"
+ " `name#discriminator`.")
-class UserMentionOrID(UserConverter):
+class UnambiguousUser(UserConverter):
"""
- Converts to a `discord.User`, but only if a mention or userID is provided.
+ Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided.
- Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim.
- This is useful in cases where that lookup strategy would lead to ambiguity.
+ Unlike the default `UserConverter`, it doesn't allow conversion from a name.
+ This is useful in cases where that lookup strategy would lead to too much ambiguity.
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
- """Convert the `arg` to a `discord.User`."""
- match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)
-
- if match is not None:
+ """Convert the `argument` to a `discord.User`."""
+ if _is_an_unambiguous_user_argument(argument):
return await super().convert(ctx, argument)
else:
- raise BadArgument(f"`{argument}` is not a User mention or a User ID.")
+ raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))
-class FetchedUser(UserConverter):
+class UnambiguousMember(MemberConverter):
"""
- Converts to a `discord.User` or, if it fails, a `discord.Object`.
-
- Unlike the default `UserConverter`, which only does lookups via the global user cache, this
- converter attempts to fetch the user via an API call to Discord when the using the cache is
- unsuccessful.
-
- If the fetch also fails and the error doesn't imply the user doesn't exist, then a
- `discord.Object` is returned via the `user_proxy` converter.
+ Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided.
- The lookup strategy is as follows (in order):
-
- 1. Lookup by ID.
- 2. Lookup by mention.
- 3. Lookup by name#discrim
- 4. Lookup by name
- 5. Lookup via API
- 6. Create a proxy user with discord.Object
- """
-
- async def convert(self, ctx: Context, arg: str) -> t.Union[discord.User, discord.Object]:
- """Convert the `arg` to a `discord.User` or `discord.Object`."""
- try:
- return await super().convert(ctx, arg)
- except BadArgument:
- pass
-
- try:
- user_id = int(arg)
- log.trace(f"Fetching user {user_id}...")
- return await ctx.bot.fetch_user(user_id)
- except ValueError:
- log.debug(f"Failed to fetch user {arg}: could not convert to int.")
- raise BadArgument(f"The provided argument can't be turned into integer: `{arg}`")
- except discord.HTTPException as e:
- # If the Discord error isn't `Unknown user`, return a proxy instead
- if e.code != 10013:
- log.info(f"Failed to fetch user, returning a proxy instead: status {e.status}")
- return proxy_user(arg)
-
- log.debug(f"Failed to fetch user {arg}: user does not exist.")
- raise BadArgument(f"User `{arg}` does not exist")
-
-
-def _snowflake_from_regex(pattern: t.Pattern, arg: str) -> int:
- """
- Extract the snowflake from `arg` using a regex `pattern` and return it as an int.
-
- The snowflake is expected to be within the first capture group in `pattern`.
+ Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname.
+ This is useful in cases where that lookup strategy would lead to too much ambiguity.
"""
- match = pattern.match(arg)
- if not match:
- raise BadArgument(f"Mention {str!r} is invalid.")
- return int(match.group(1))
+ async def convert(self, ctx: Context, argument: str) -> discord.Member:
+ """Convert the `argument` to a `discord.Member`."""
+ if _is_an_unambiguous_user_argument(argument):
+ return await super().convert(ctx, argument)
+ else:
+ raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))
class Infraction(Converter):
@@ -554,7 +552,7 @@ class Infraction(Converter):
"ordering": "-inserted_at"
}
- infractions = await ctx.bot.api_client.get("bot/infractions", params=params)
+ infractions = await ctx.bot.api_client.get("bot/infractions/expanded", params=params)
if not infractions:
raise BadArgument(
@@ -564,9 +562,37 @@ class Infraction(Converter):
return infractions[0]
else:
- return await ctx.bot.api_client.get(f"bot/infractions/{arg}")
-
+ try:
+ return await ctx.bot.api_client.get(f"bot/infractions/{arg}/expanded")
+ except ResponseCodeError as e:
+ if e.status == 404:
+ raise InvalidInfraction(
+ converter=Infraction,
+ original=e,
+ infraction_arg=arg
+ )
+ raise e
+
+
+if t.TYPE_CHECKING:
+ ValidDiscordServerInvite = dict # noqa: F811
+ ValidFilterListType = str # noqa: F811
+ Extension = str # noqa: F811
+ PackageName = str # noqa: F811
+ ValidURL = str # noqa: F811
+ Inventory = t.Tuple[str, _inventory_parser.InventoryDict] # noqa: F811
+ Snowflake = int # noqa: F811
+ SourceConverter = SourceType # noqa: F811
+ DurationDelta = relativedelta # noqa: F811
+ Duration = datetime # noqa: F811
+ Age = datetime # noqa: F811
+ OffTopicName = str # noqa: F811
+ ISODateTime = datetime # noqa: F811
+ HushDurationConverter = int # noqa: F811
+ UnambiguousUser = discord.User # noqa: F811
+ UnambiguousMember = discord.Member # noqa: F811
+ Infraction = t.Optional[dict] # noqa: F811
Expiry = t.Union[Duration, ISODateTime]
-FetchedMember = t.Union[discord.Member, FetchedUser]
-UserMention = partial(_snowflake_from_regex, RE_USER_MENTION)
+MemberOrUser = t.Union[discord.Member, discord.User]
+UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser]