diff options
Diffstat (limited to 'bot/utils')
-rw-r--r-- | bot/utils/__init__.py | 19 | ||||
-rw-r--r-- | bot/utils/checks.py | 9 | ||||
-rw-r--r-- | bot/utils/converters.py | 9 | ||||
-rw-r--r-- | bot/utils/decorators.py | 28 | ||||
-rw-r--r-- | bot/utils/exceptions.py | 5 | ||||
-rw-r--r-- | bot/utils/messages.py | 25 | ||||
-rw-r--r-- | bot/utils/pagination.py | 45 | ||||
-rw-r--r-- | bot/utils/randomization.py | 8 | ||||
-rw-r--r-- | bot/utils/time.py | 13 |
9 files changed, 74 insertions, 87 deletions
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 91682dbc..ddc2d111 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -3,8 +3,7 @@ import contextlib import re import string from collections.abc import Iterable -from datetime import datetime -from typing import Optional +from datetime import UTC, datetime import discord from discord.ext.commands import BadArgument, Context @@ -27,8 +26,7 @@ def resolve_current_month() -> Month: """ if Client.month_override is not None: return Month(Client.month_override) - else: - return Month(datetime.utcnow().month) + return Month(datetime.now(tz=UTC).month) async def disambiguate( @@ -38,7 +36,7 @@ async def disambiguate( timeout: float = 30, entries_per_page: int = 20, empty: bool = False, - embed: Optional[discord.Embed] = None + embed: discord.Embed | None = None ) -> str: """ Has the user choose between multiple entries in case one could not be chosen automatically. @@ -130,9 +128,9 @@ def replace_many( assert var == "That WAS a sentence" """ if ignore_case: - replacements = dict( - (word.lower(), replacement) for word, replacement in replacements.items() - ) + replacements = { + word.lower(): replacement for word, replacement in replacements.items() + } words_to_replace = sorted(replacements, key=lambda s: (-len(s), s)) @@ -152,10 +150,9 @@ def replace_many( cleaned_word = word.translate(str.maketrans("", "", string.punctuation)) if cleaned_word.isupper(): return replacement.upper() - elif cleaned_word[0].isupper(): + if cleaned_word[0].isupper(): return replacement.capitalize() - else: - return replacement.lower() + return replacement.lower() return regex.sub(_repl, sentence) diff --git a/bot/utils/checks.py b/bot/utils/checks.py index f21d2ddd..418bb7ad 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -1,7 +1,6 @@ import datetime import logging -from collections.abc import Container, Iterable -from typing import Callable, Optional +from collections.abc import Callable, Container, Iterable from discord.ext.commands import ( BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping @@ -15,7 +14,7 @@ log = logging.getLogger(__name__) class InWhitelistCheckFailure(CheckFailure): """Raised when the `in_whitelist` check fails.""" - def __init__(self, redirect_channel: Optional[int]): + def __init__(self, redirect_channel: int | None): self.redirect_channel = redirect_channel if redirect_channel: @@ -33,7 +32,7 @@ def in_whitelist_check( channels: Container[int] = (), categories: Container[int] = (), roles: Container[int] = (), - redirect: Optional[int] = constants.Channels.sir_lancebot_playground, + redirect: int | None = constants.Channels.sir_lancebot_playground, fail_silently: bool = False, ) -> bool: """ @@ -153,7 +152,7 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy return # Cooldown logic, taken from discord.py internals. - current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp() + current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp() bucket = buckets.get_bucket(ctx.message) retry_after = bucket.update_rate_limit(current) if retry_after: diff --git a/bot/utils/converters.py b/bot/utils/converters.py index 7227a406..6111b87d 100644 --- a/bot/utils/converters.py +++ b/bot/utils/converters.py @@ -1,5 +1,4 @@ -from datetime import datetime -from typing import Union +from datetime import UTC, datetime import discord from discord.ext import commands @@ -47,7 +46,7 @@ class CoordinateConverter(commands.Converter): return x, y -SourceType = Union[commands.Command, commands.Cog] +SourceType = commands.Command | commands.Cog class SourceConverter(commands.Converter): @@ -73,12 +72,12 @@ class DateConverter(commands.Converter): """Parse SOL or earth date (in format YYYY-MM-DD) into `int` or `datetime`. When invalid input, raise error.""" @staticmethod - async def convert(ctx: commands.Context, argument: str) -> Union[int, datetime]: + async def convert(ctx: commands.Context, argument: str) -> int | datetime: """Parse date (SOL or earth) into `datetime` or `int`. When invalid value, raise error.""" if argument.isdecimal(): return int(argument) try: - date = datetime.strptime(argument, "%Y-%m-%d") + date = datetime.strptime(argument, "%Y-%m-%d").replace(tzinfo=UTC) except ValueError: raise commands.BadArgument( f"Can't convert `{argument}` to `datetime` in format `YYYY-MM-DD` or `int` in SOL." diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 442eb841..1cbad504 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -3,9 +3,8 @@ import functools import logging import random from asyncio import Lock -from collections.abc import Container +from collections.abc import Callable, Container from functools import wraps -from typing import Callable, Optional, Union from weakref import WeakValueDictionary from discord import Colour, Embed @@ -24,16 +23,12 @@ log = logging.getLogger(__name__) class InChannelCheckFailure(CheckFailure): """Check failure when the user runs a command in a non-whitelisted channel.""" - pass - class InMonthCheckFailure(CheckFailure): """Check failure for when a command is invoked outside of its allowed month.""" - pass - -def seasonal_task(*allowed_months: Month, sleep_time: Union[float, int] = ONE_DAY) -> Callable: +def seasonal_task(*allowed_months: Month, sleep_time: float | int = ONE_DAY) -> Callable: """ Perform the decorated method periodically in `allowed_months`. @@ -79,8 +74,8 @@ def in_month_listener(*allowed_months: Month) -> Callable: if current_month in allowed_months: # Propagate return value although it should always be None return await listener(*args, **kwargs) - else: - log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month!s}") + log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month!s}") + return None return guarded_listener return decorator @@ -101,8 +96,7 @@ def in_month_command(*allowed_months: Month) -> Callable: ) if can_run: return True - else: - raise InMonthCheckFailure(f"Command can only be used in {human_months(allowed_months)}") + raise InMonthCheckFailure(f"Command can only be used in {human_months(allowed_months)}") return commands.check(predicate) @@ -201,13 +195,13 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo # Determine which command's overrides we will use. Group commands will # inherit from their parents if they don't define their own overrides - overridden_command: Optional[commands.Command] = None + overridden_command: commands.Command | None = None for command in [ctx.command, *ctx.command.parents]: if hasattr(command.callback, "override"): overridden_command = command break if overridden_command is not None: - log.debug(f'Command {overridden_command} has overrides') + log.debug(f"Command {overridden_command} has overrides") if overridden_command is not ctx.command: log.debug( f"Command '{ctx.command.qualified_name}' inherited overrides " @@ -319,7 +313,7 @@ def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, ** return inner -def locked() -> Optional[Callable]: +def locked() -> Callable | None: """ Allows the user to only run one instance of the decorated command at a time. @@ -327,11 +321,11 @@ def locked() -> Optional[Callable]: This decorator has to go before (below) the `command` decorator. """ - def wrap(func: Callable) -> Optional[Callable]: + def wrap(func: Callable) -> Callable | None: func.__locks = WeakValueDictionary() @wraps(func) - async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Optional[Callable]: + async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Callable | None: lock = func.__locks.setdefault(ctx.author.id, Lock()) if lock.locked(): embed = Embed() @@ -344,7 +338,7 @@ def locked() -> Optional[Callable]: ) embed.title = random.choice(ERROR_REPLIES) await ctx.send(embed=embed) - return + return None async with func.__locks.setdefault(ctx.author.id, Lock()): return await func(self, ctx, *args, **kwargs) diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index 3cd96325..b1a35e63 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -1,16 +1,13 @@ -from typing import Optional class UserNotPlayingError(Exception): """Raised when users try to use game commands when they are not playing.""" - pass - class APIError(Exception): """Raised when an external API (eg. Wikipedia) returns an error response.""" - def __init__(self, api: str, status_code: int, error_msg: Optional[str] = None): + def __init__(self, api: str, status_code: int, error_msg: str | None = None): super().__init__() self.api = api self.status_code = status_code diff --git a/bot/utils/messages.py b/bot/utils/messages.py index b0c95583..4fb0b39b 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -1,6 +1,7 @@ +import contextlib import logging import re -from typing import Callable, Optional, Union +from collections.abc import Callable from discord import Embed, Message from discord.ext import commands @@ -9,39 +10,35 @@ from discord.ext.commands import Context, MessageConverter log = logging.getLogger(__name__) -def sub_clyde(username: Optional[str]) -> Optional[str]: +def sub_clyde(username: str | None) -> str | None: """ - Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"E" and return the new string. + Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"Е" and return the new string. Discord disallows "clyde" anywhere in the username for webhooks. It will return a 400. Return None only if `username` is None. - """ + """ # noqa: RUF002 def replace_e(match: re.Match) -> str: - char = "е" if match[2] == "e" else "Е" + char = "е" if match[2] == "e" else "Е" # noqa: RUF001 return match[1] + char if username: return re.sub(r"(clyd)(e)", replace_e, username, flags=re.I) - else: - return username # Empty string or None + return username # Empty string or None -async def get_discord_message(ctx: Context, text: str) -> Union[Message, str]: +async def get_discord_message(ctx: Context, text: str) -> Message | str: """ Attempts to convert a given `text` to a discord Message object and return it. Conversion will succeed if given a discord Message ID or link. Returns `text` if the conversion fails. """ - try: + with contextlib.suppress(commands.BadArgument): text = await MessageConverter().convert(ctx, text) - except commands.BadArgument: - pass - return text -async def get_text_and_embed(ctx: Context, text: str) -> tuple[str, Optional[Embed]]: +async def get_text_and_embed(ctx: Context, text: str) -> tuple[str, Embed | None]: """ Attempts to extract the text and embed from a possible link to a discord Message. @@ -52,7 +49,7 @@ async def get_text_and_embed(ctx: Context, text: str) -> tuple[str, Optional[Emb str: If `text` is a valid discord Message, the contents of the message, else `text`. Optional[Embed]: The embed if found in the valid Message, else None """ - embed: Optional[Embed] = None + embed: Embed | None = None msg = await get_discord_message(ctx, text) # Ensure the user has read permissions for the channel the message is in diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py index b291f7db..df0eb942 100644 --- a/bot/utils/pagination.py +++ b/bot/utils/pagination.py @@ -1,7 +1,6 @@ import asyncio import logging from collections.abc import Iterable -from typing import Optional from discord import Embed, Member, Reaction from discord.abc import User @@ -29,10 +28,10 @@ class LinePaginator(Paginator): def __init__( self, - prefix: str = '```', - suffix: str = '```', + prefix: str = "```", + suffix: str = "```", max_size: int = 2000, - max_lines: Optional[int] = None, + max_lines: int | None = None, linesep: str = "\n" ): """ @@ -87,11 +86,13 @@ class LinePaginator(Paginator): self._count += 1 @classmethod - async def paginate(cls, lines: Iterable[str], ctx: Context, embed: Embed, - prefix: str = "", suffix: str = "", max_lines: Optional[int] = None, - max_size: int = 500, empty: bool = True, restrict_to_user: User = None, - timeout: int = 300, footer_text: str = None, url: str = None, - exception_on_empty_embed: bool = False) -> None: + async def paginate( + cls, lines: Iterable[str], ctx: Context, + embed: Embed, prefix: str = "", suffix: str = "", + max_lines: int | None = None, max_size: int = 500, empty: bool = True, + restrict_to_user: User = None, timeout: int = 300, footer_text: str = None, + url: str = None, exception_on_empty_embed: bool = False + ) -> None: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -170,20 +171,20 @@ class LinePaginator(Paginator): log.debug("There's less than two pages, so we won't paginate - sending single page on its own") await ctx.send(embed=embed) - return + return None + + if footer_text: + embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") else: - if footer_text: - embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") - else: - embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}") - log.trace(f"Setting embed footer to '{embed.footer.text}'") + embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}") + log.trace(f"Setting embed footer to '{embed.footer.text}'") - if url: - embed.url = url - log.trace(f"Setting embed url to '{url}'") + if url: + embed.url = url + log.trace(f"Setting embed url to '{url}'") - log.debug("Sending first page to channel...") - message = await ctx.send(embed=embed) + log.debug("Sending first page to channel...") + message = await ctx.send(embed=embed) log.debug("Adding emoji reactions to message...") @@ -270,6 +271,7 @@ class LinePaginator(Paginator): log.debug("Ending pagination and clearing reactions...") await message.clear_reactions() + return None class ImagePaginator(Paginator): @@ -358,7 +360,7 @@ class ImagePaginator(Paginator): if len(paginator.pages) <= 1: await ctx.send(embed=embed) - return + return None embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}") message = await ctx.send(embed=embed) @@ -431,3 +433,4 @@ class ImagePaginator(Paginator): log.debug("Ending pagination and clearing reactions...") await message.clear_reactions() + return None diff --git a/bot/utils/randomization.py b/bot/utils/randomization.py index c9eabbd2..1caff3fa 100644 --- a/bot/utils/randomization.py +++ b/bot/utils/randomization.py @@ -1,7 +1,9 @@ import itertools import random from collections.abc import Iterable -from typing import Any +from typing import TypeVar + +T = TypeVar("T") class RandomCycle: @@ -11,11 +13,11 @@ class RandomCycle: The iterable is reshuffled after each full cycle. """ - def __init__(self, iterable: Iterable): + def __init__(self, iterable: Iterable[T]): self.iterable = list(iterable) self.index = itertools.cycle(range(len(iterable))) - def __next__(self) -> Any: + def __next__(self) -> T: idx = next(self.index) if idx == 0: diff --git a/bot/utils/time.py b/bot/utils/time.py index fbf2fd21..66f9e7cb 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -1,4 +1,4 @@ -import datetime +from datetime import UTC, datetime from dateutil.relativedelta import relativedelta @@ -17,12 +17,11 @@ def _stringify_time_unit(value: int, unit: str) -> str: """ if unit == "seconds" and value == 0: return "0 seconds" - elif value == 1: + if value == 1: return f"{value} {unit[:-1]}" - elif value == 0: + if value == 0: return f"less than a {unit[:-1]}" - else: - return f"{value} {unit}" + return f"{value} {unit}" def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str: @@ -69,14 +68,14 @@ def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: return humanized -def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max_units: int = 6) -> str: +def time_since(past_datetime: datetime, precision: str = "seconds", max_units: int = 6) -> str: """ Takes a datetime and returns a human-readable string that describes how long ago that datetime was. precision specifies the smallest unit of time to include (e.g. "seconds", "minutes"). max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). """ - now = datetime.datetime.utcnow() + now = datetime.now(tz=UTC) delta = abs(relativedelta(now, past_datetime)) humanized = humanize_delta(delta, precision, max_units) |