aboutsummaryrefslogtreecommitdiffstats
path: root/bot/converters.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot/converters.py')
-rw-r--r--bot/converters.py98
1 files changed, 73 insertions, 25 deletions
diff --git a/bot/converters.py b/bot/converters.py
index 546f6e8f4..5aaff4e76 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import logging
import re
import typing as t
from datetime import datetime
@@ -11,20 +10,23 @@ import dateutil.tz
import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
-from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, UserConverter
+from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter
from discord.utils import DISCORD_EPOCH, escape_markdown, snowflake_time
from bot import exts
from bot.api import ResponseCodeError
from bot.constants import URLs
+from bot.errors import InvalidInfraction
from bot.exts.info.doc import _inventory_parser
+from bot.log import get_logger
from bot.utils.extensions import EXTENSIONS, unqualify
from bot.utils.regex import INVITE_RE
from bot.utils.time import parse_duration_string
+
if t.TYPE_CHECKING:
from bot.exts.info.source import SourceType
-log = logging.getLogger(__name__)
+log = get_logger(__name__)
DISCORD_EPOCH_DT = datetime.utcfromtimestamp(DISCORD_EPOCH / 1000)
RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")
@@ -69,10 +71,10 @@ class ValidDiscordServerInvite(Converter):
async def convert(self, ctx: Context, server_invite: str) -> dict:
"""Check whether the string is a valid Discord server invite."""
- invite_code = INVITE_RE.search(server_invite)
+ invite_code = INVITE_RE.match(server_invite)
if invite_code:
response = await ctx.bot.http_session.get(
- f"{URLs.discord_invite_api}/{invite_code[1]}"
+ f"{URLs.discord_invite_api}/{invite_code.group('invite')}"
)
if response.status != 404:
invite_data = await response.json()
@@ -234,11 +236,16 @@ class Inventory(Converter):
async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]:
"""Convert url to Intersphinx inventory URL."""
await ctx.trigger_typing()
- if (inventory := await _inventory_parser.fetch_inventory(url)) is None:
- raise BadArgument(
- f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts."
- )
- return url, inventory
+ try:
+ inventory = await _inventory_parser.fetch_inventory(url)
+ except _inventory_parser.InvalidHeaderError:
+ raise BadArgument("Unable to parse inventory because of invalid header, check if URL is correct.")
+ else:
+ if inventory is None:
+ raise BadArgument(
+ f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts."
+ )
+ return url, inventory
class Snowflake(IDConverter):
@@ -266,7 +273,7 @@ class Snowflake(IDConverter):
snowflake = int(arg)
try:
- time = snowflake_time(snowflake)
+ time = snowflake_time(snowflake).replace(tzinfo=None)
except (OverflowError, OSError) as e:
# Not sure if this can ever even happen, but let's be safe.
raise BadArgument(f"{error}: {e}")
@@ -409,7 +416,8 @@ class Age(DurationDelta):
class OffTopicName(Converter):
"""A converter that ensures an added off-topic name is valid."""
- ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-"
+ ALLOWED_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>"
+ TRANSLATED_CHARACTERS = "๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-๏ผœ๏ผž"
@classmethod
def translate_name(cls, name: str, *, from_unicode: bool = True) -> str:
@@ -419,9 +427,9 @@ class OffTopicName(Converter):
If `from_unicode` is True, the name is translated from a discord-safe format, back to normalized text.
"""
if from_unicode:
- table = str.maketrans(cls.ALLOWED_CHARACTERS, '๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-')
+ table = str.maketrans(cls.ALLOWED_CHARACTERS, cls.TRANSLATED_CHARACTERS)
else:
- table = str.maketrans('๐– ๐–ก๐–ข๐–ฃ๐–ค๐–ฅ๐–ฆ๐–ง๐–จ๐–ฉ๐–ช๐–ซ๐–ฌ๐–ญ๐–ฎ๐–ฏ๐–ฐ๐–ฑ๐–ฒ๐–ณ๐–ด๐–ต๐–ถ๐–ท๐–ธ๐–นวƒ๏ผŸโ€™โ€™-', cls.ALLOWED_CHARACTERS)
+ table = str.maketrans(cls.TRANSLATED_CHARACTERS, cls.ALLOWED_CHARACTERS)
return name.translate(table)
@@ -513,22 +521,51 @@ class HushDurationConverter(Converter):
return duration
-class UserMentionOrID(UserConverter):
+def _is_an_unambiguous_user_argument(argument: str) -> bool:
+ """Check if the provided argument is a user mention, user id, or username (name#discrim)."""
+ has_id_or_mention = bool(IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument))
+
+ # Check to see if the author passed a username (a discriminator exists)
+ argument = argument.removeprefix('@')
+ has_username = len(argument) > 5 and argument[-5] == '#'
+
+ return has_id_or_mention or has_username
+
+
+AMBIGUOUS_ARGUMENT_MSG = ("`{argument}` is not a User mention, a User ID or a Username in the format"
+ " `name#discriminator`.")
+
+
+class UnambiguousUser(UserConverter):
"""
- Converts to a `discord.User`, but only if a mention or userID is provided.
+ Converts to a `discord.User`, but only if a mention, userID or a username (name#discrim) is provided.
- Unlike the default `UserConverter`, it doesn't allow conversion from a name or name#descrim.
- This is useful in cases where that lookup strategy would lead to ambiguity.
+ Unlike the default `UserConverter`, it doesn't allow conversion from a name.
+ This is useful in cases where that lookup strategy would lead to too much ambiguity.
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
- """Convert the `arg` to a `discord.User`."""
- match = self._get_id_match(argument) or RE_USER_MENTION.match(argument)
+ """Convert the `argument` to a `discord.User`."""
+ if _is_an_unambiguous_user_argument(argument):
+ return await super().convert(ctx, argument)
+ else:
+ raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))
+
- if match is not None:
+class UnambiguousMember(MemberConverter):
+ """
+ Converts to a `discord.Member`, but only if a mention, userID or a username (name#discrim) is provided.
+
+ Unlike the default `MemberConverter`, it doesn't allow conversion from a name or nickname.
+ This is useful in cases where that lookup strategy would lead to too much ambiguity.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> discord.Member:
+ """Convert the `argument` to a `discord.Member`."""
+ if _is_an_unambiguous_user_argument(argument):
return await super().convert(ctx, argument)
else:
- raise BadArgument(f"`{argument}` is not a User mention or a User ID.")
+ raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))
class Infraction(Converter):
@@ -547,7 +584,7 @@ class Infraction(Converter):
"ordering": "-inserted_at"
}
- infractions = await ctx.bot.api_client.get("bot/infractions", params=params)
+ infractions = await ctx.bot.api_client.get("bot/infractions/expanded", params=params)
if not infractions:
raise BadArgument(
@@ -557,7 +594,16 @@ class Infraction(Converter):
return infractions[0]
else:
- return await ctx.bot.api_client.get(f"bot/infractions/{arg}")
+ try:
+ return await ctx.bot.api_client.get(f"bot/infractions/{arg}/expanded")
+ except ResponseCodeError as e:
+ if e.status == 404:
+ raise InvalidInfraction(
+ converter=Infraction,
+ original=e,
+ infraction_arg=arg
+ )
+ raise e
if t.TYPE_CHECKING:
@@ -576,8 +622,10 @@ if t.TYPE_CHECKING:
OffTopicName = str # noqa: F811
ISODateTime = datetime # noqa: F811
HushDurationConverter = int # noqa: F811
- UserMentionOrID = discord.User # noqa: F811
+ UnambiguousUser = discord.User # noqa: F811
+ UnambiguousMember = discord.Member # noqa: F811
Infraction = t.Optional[dict] # noqa: F811
Expiry = t.Union[Duration, ISODateTime]
MemberOrUser = t.Union[discord.Member, discord.User]
+UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser]