aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
authorGravatar Chris Lovering <[email protected]>2023-05-06 16:12:32 +0100
committerGravatar Chris Lovering <[email protected]>2023-05-09 15:41:50 +0100
commit613840ebcf303e84048d48ace37fb001c1afe687 (patch)
tree9acaf0bae0527fe8389483a419b44e06997ca060 /bot/utils
parentMigrate to ruff (diff)
Apply fixes for ruff linting
Co-authored-by: wookie184 <[email protected]> Co-authored-by: Amrou Bellalouna <[email protected]>
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/__init__.py19
-rw-r--r--bot/utils/checks.py9
-rw-r--r--bot/utils/converters.py9
-rw-r--r--bot/utils/decorators.py28
-rw-r--r--bot/utils/exceptions.py5
-rw-r--r--bot/utils/messages.py25
-rw-r--r--bot/utils/pagination.py45
-rw-r--r--bot/utils/randomization.py8
-rw-r--r--bot/utils/time.py13
9 files changed, 74 insertions, 87 deletions
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 91682dbc..ddc2d111 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -3,8 +3,7 @@ import contextlib
import re
import string
from collections.abc import Iterable
-from datetime import datetime
-from typing import Optional
+from datetime import UTC, datetime
import discord
from discord.ext.commands import BadArgument, Context
@@ -27,8 +26,7 @@ def resolve_current_month() -> Month:
"""
if Client.month_override is not None:
return Month(Client.month_override)
- else:
- return Month(datetime.utcnow().month)
+ return Month(datetime.now(tz=UTC).month)
async def disambiguate(
@@ -38,7 +36,7 @@ async def disambiguate(
timeout: float = 30,
entries_per_page: int = 20,
empty: bool = False,
- embed: Optional[discord.Embed] = None
+ embed: discord.Embed | None = None
) -> str:
"""
Has the user choose between multiple entries in case one could not be chosen automatically.
@@ -130,9 +128,9 @@ def replace_many(
assert var == "That WAS a sentence"
"""
if ignore_case:
- replacements = dict(
- (word.lower(), replacement) for word, replacement in replacements.items()
- )
+ replacements = {
+ word.lower(): replacement for word, replacement in replacements.items()
+ }
words_to_replace = sorted(replacements, key=lambda s: (-len(s), s))
@@ -152,10 +150,9 @@ def replace_many(
cleaned_word = word.translate(str.maketrans("", "", string.punctuation))
if cleaned_word.isupper():
return replacement.upper()
- elif cleaned_word[0].isupper():
+ if cleaned_word[0].isupper():
return replacement.capitalize()
- else:
- return replacement.lower()
+ return replacement.lower()
return regex.sub(_repl, sentence)
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index f21d2ddd..418bb7ad 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -1,7 +1,6 @@
import datetime
import logging
-from collections.abc import Container, Iterable
-from typing import Callable, Optional
+from collections.abc import Callable, Container, Iterable
from discord.ext.commands import (
BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping
@@ -15,7 +14,7 @@ log = logging.getLogger(__name__)
class InWhitelistCheckFailure(CheckFailure):
"""Raised when the `in_whitelist` check fails."""
- def __init__(self, redirect_channel: Optional[int]):
+ def __init__(self, redirect_channel: int | None):
self.redirect_channel = redirect_channel
if redirect_channel:
@@ -33,7 +32,7 @@ def in_whitelist_check(
channels: Container[int] = (),
categories: Container[int] = (),
roles: Container[int] = (),
- redirect: Optional[int] = constants.Channels.sir_lancebot_playground,
+ redirect: int | None = constants.Channels.sir_lancebot_playground,
fail_silently: bool = False,
) -> bool:
"""
@@ -153,7 +152,7 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy
return
# Cooldown logic, taken from discord.py internals.
- current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp()
+ current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp()
bucket = buckets.get_bucket(ctx.message)
retry_after = bucket.update_rate_limit(current)
if retry_after:
diff --git a/bot/utils/converters.py b/bot/utils/converters.py
index 7227a406..6111b87d 100644
--- a/bot/utils/converters.py
+++ b/bot/utils/converters.py
@@ -1,5 +1,4 @@
-from datetime import datetime
-from typing import Union
+from datetime import UTC, datetime
import discord
from discord.ext import commands
@@ -47,7 +46,7 @@ class CoordinateConverter(commands.Converter):
return x, y
-SourceType = Union[commands.Command, commands.Cog]
+SourceType = commands.Command | commands.Cog
class SourceConverter(commands.Converter):
@@ -73,12 +72,12 @@ class DateConverter(commands.Converter):
"""Parse SOL or earth date (in format YYYY-MM-DD) into `int` or `datetime`. When invalid input, raise error."""
@staticmethod
- async def convert(ctx: commands.Context, argument: str) -> Union[int, datetime]:
+ async def convert(ctx: commands.Context, argument: str) -> int | datetime:
"""Parse date (SOL or earth) into `datetime` or `int`. When invalid value, raise error."""
if argument.isdecimal():
return int(argument)
try:
- date = datetime.strptime(argument, "%Y-%m-%d")
+ date = datetime.strptime(argument, "%Y-%m-%d").replace(tzinfo=UTC)
except ValueError:
raise commands.BadArgument(
f"Can't convert `{argument}` to `datetime` in format `YYYY-MM-DD` or `int` in SOL."
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 442eb841..1cbad504 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -3,9 +3,8 @@ import functools
import logging
import random
from asyncio import Lock
-from collections.abc import Container
+from collections.abc import Callable, Container
from functools import wraps
-from typing import Callable, Optional, Union
from weakref import WeakValueDictionary
from discord import Colour, Embed
@@ -24,16 +23,12 @@ log = logging.getLogger(__name__)
class InChannelCheckFailure(CheckFailure):
"""Check failure when the user runs a command in a non-whitelisted channel."""
- pass
-
class InMonthCheckFailure(CheckFailure):
"""Check failure for when a command is invoked outside of its allowed month."""
- pass
-
-def seasonal_task(*allowed_months: Month, sleep_time: Union[float, int] = ONE_DAY) -> Callable:
+def seasonal_task(*allowed_months: Month, sleep_time: float | int = ONE_DAY) -> Callable:
"""
Perform the decorated method periodically in `allowed_months`.
@@ -79,8 +74,8 @@ def in_month_listener(*allowed_months: Month) -> Callable:
if current_month in allowed_months:
# Propagate return value although it should always be None
return await listener(*args, **kwargs)
- else:
- log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month!s}")
+ log.debug(f"Guarded {listener.__qualname__} from invoking in {current_month!s}")
+ return None
return guarded_listener
return decorator
@@ -101,8 +96,7 @@ def in_month_command(*allowed_months: Month) -> Callable:
)
if can_run:
return True
- else:
- raise InMonthCheckFailure(f"Command can only be used in {human_months(allowed_months)}")
+ raise InMonthCheckFailure(f"Command can only be used in {human_months(allowed_months)}")
return commands.check(predicate)
@@ -201,13 +195,13 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
# Determine which command's overrides we will use. Group commands will
# inherit from their parents if they don't define their own overrides
- overridden_command: Optional[commands.Command] = None
+ overridden_command: commands.Command | None = None
for command in [ctx.command, *ctx.command.parents]:
if hasattr(command.callback, "override"):
overridden_command = command
break
if overridden_command is not None:
- log.debug(f'Command {overridden_command} has overrides')
+ log.debug(f"Command {overridden_command} has overrides")
if overridden_command is not ctx.command:
log.debug(
f"Command '{ctx.command.qualified_name}' inherited overrides "
@@ -319,7 +313,7 @@ def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **
return inner
-def locked() -> Optional[Callable]:
+def locked() -> Callable | None:
"""
Allows the user to only run one instance of the decorated command at a time.
@@ -327,11 +321,11 @@ def locked() -> Optional[Callable]:
This decorator has to go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Optional[Callable]:
+ def wrap(func: Callable) -> Callable | None:
func.__locks = WeakValueDictionary()
@wraps(func)
- async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Optional[Callable]:
+ async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Callable | None:
lock = func.__locks.setdefault(ctx.author.id, Lock())
if lock.locked():
embed = Embed()
@@ -344,7 +338,7 @@ def locked() -> Optional[Callable]:
)
embed.title = random.choice(ERROR_REPLIES)
await ctx.send(embed=embed)
- return
+ return None
async with func.__locks.setdefault(ctx.author.id, Lock()):
return await func(self, ctx, *args, **kwargs)
diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py
index 3cd96325..b1a35e63 100644
--- a/bot/utils/exceptions.py
+++ b/bot/utils/exceptions.py
@@ -1,16 +1,13 @@
-from typing import Optional
class UserNotPlayingError(Exception):
"""Raised when users try to use game commands when they are not playing."""
- pass
-
class APIError(Exception):
"""Raised when an external API (eg. Wikipedia) returns an error response."""
- def __init__(self, api: str, status_code: int, error_msg: Optional[str] = None):
+ def __init__(self, api: str, status_code: int, error_msg: str | None = None):
super().__init__()
self.api = api
self.status_code = status_code
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index b0c95583..4fb0b39b 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -1,6 +1,7 @@
+import contextlib
import logging
import re
-from typing import Callable, Optional, Union
+from collections.abc import Callable
from discord import Embed, Message
from discord.ext import commands
@@ -9,39 +10,35 @@ from discord.ext.commands import Context, MessageConverter
log = logging.getLogger(__name__)
-def sub_clyde(username: Optional[str]) -> Optional[str]:
+def sub_clyde(username: str | None) -> str | None:
"""
- Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"E" and return the new string.
+ Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"Е" and return the new string.
Discord disallows "clyde" anywhere in the username for webhooks. It will return a 400.
Return None only if `username` is None.
- """
+ """ # noqa: RUF002
def replace_e(match: re.Match) -> str:
- char = "е" if match[2] == "e" else "Е"
+ char = "е" if match[2] == "e" else "Е" # noqa: RUF001
return match[1] + char
if username:
return re.sub(r"(clyd)(e)", replace_e, username, flags=re.I)
- else:
- return username # Empty string or None
+ return username # Empty string or None
-async def get_discord_message(ctx: Context, text: str) -> Union[Message, str]:
+async def get_discord_message(ctx: Context, text: str) -> Message | str:
"""
Attempts to convert a given `text` to a discord Message object and return it.
Conversion will succeed if given a discord Message ID or link.
Returns `text` if the conversion fails.
"""
- try:
+ with contextlib.suppress(commands.BadArgument):
text = await MessageConverter().convert(ctx, text)
- except commands.BadArgument:
- pass
-
return text
-async def get_text_and_embed(ctx: Context, text: str) -> tuple[str, Optional[Embed]]:
+async def get_text_and_embed(ctx: Context, text: str) -> tuple[str, Embed | None]:
"""
Attempts to extract the text and embed from a possible link to a discord Message.
@@ -52,7 +49,7 @@ async def get_text_and_embed(ctx: Context, text: str) -> tuple[str, Optional[Emb
str: If `text` is a valid discord Message, the contents of the message, else `text`.
Optional[Embed]: The embed if found in the valid Message, else None
"""
- embed: Optional[Embed] = None
+ embed: Embed | None = None
msg = await get_discord_message(ctx, text)
# Ensure the user has read permissions for the channel the message is in
diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py
index b291f7db..df0eb942 100644
--- a/bot/utils/pagination.py
+++ b/bot/utils/pagination.py
@@ -1,7 +1,6 @@
import asyncio
import logging
from collections.abc import Iterable
-from typing import Optional
from discord import Embed, Member, Reaction
from discord.abc import User
@@ -29,10 +28,10 @@ class LinePaginator(Paginator):
def __init__(
self,
- prefix: str = '```',
- suffix: str = '```',
+ prefix: str = "```",
+ suffix: str = "```",
max_size: int = 2000,
- max_lines: Optional[int] = None,
+ max_lines: int | None = None,
linesep: str = "\n"
):
"""
@@ -87,11 +86,13 @@ class LinePaginator(Paginator):
self._count += 1
@classmethod
- async def paginate(cls, lines: Iterable[str], ctx: Context, embed: Embed,
- prefix: str = "", suffix: str = "", max_lines: Optional[int] = None,
- max_size: int = 500, empty: bool = True, restrict_to_user: User = None,
- timeout: int = 300, footer_text: str = None, url: str = None,
- exception_on_empty_embed: bool = False) -> None:
+ async def paginate(
+ cls, lines: Iterable[str], ctx: Context,
+ embed: Embed, prefix: str = "", suffix: str = "",
+ max_lines: int | None = None, max_size: int = 500, empty: bool = True,
+ restrict_to_user: User = None, timeout: int = 300, footer_text: str = None,
+ url: str = None, exception_on_empty_embed: bool = False
+ ) -> None:
"""
Use a paginator and set of reactions to provide pagination over a set of lines.
@@ -170,20 +171,20 @@ class LinePaginator(Paginator):
log.debug("There's less than two pages, so we won't paginate - sending single page on its own")
await ctx.send(embed=embed)
- return
+ return None
+
+ if footer_text:
+ embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
else:
- if footer_text:
- embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
- else:
- embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")
- log.trace(f"Setting embed footer to '{embed.footer.text}'")
+ embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")
+ log.trace(f"Setting embed footer to '{embed.footer.text}'")
- if url:
- embed.url = url
- log.trace(f"Setting embed url to '{url}'")
+ if url:
+ embed.url = url
+ log.trace(f"Setting embed url to '{url}'")
- log.debug("Sending first page to channel...")
- message = await ctx.send(embed=embed)
+ log.debug("Sending first page to channel...")
+ message = await ctx.send(embed=embed)
log.debug("Adding emoji reactions to message...")
@@ -270,6 +271,7 @@ class LinePaginator(Paginator):
log.debug("Ending pagination and clearing reactions...")
await message.clear_reactions()
+ return None
class ImagePaginator(Paginator):
@@ -358,7 +360,7 @@ class ImagePaginator(Paginator):
if len(paginator.pages) <= 1:
await ctx.send(embed=embed)
- return
+ return None
embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")
message = await ctx.send(embed=embed)
@@ -431,3 +433,4 @@ class ImagePaginator(Paginator):
log.debug("Ending pagination and clearing reactions...")
await message.clear_reactions()
+ return None
diff --git a/bot/utils/randomization.py b/bot/utils/randomization.py
index c9eabbd2..1caff3fa 100644
--- a/bot/utils/randomization.py
+++ b/bot/utils/randomization.py
@@ -1,7 +1,9 @@
import itertools
import random
from collections.abc import Iterable
-from typing import Any
+from typing import TypeVar
+
+T = TypeVar("T")
class RandomCycle:
@@ -11,11 +13,11 @@ class RandomCycle:
The iterable is reshuffled after each full cycle.
"""
- def __init__(self, iterable: Iterable):
+ def __init__(self, iterable: Iterable[T]):
self.iterable = list(iterable)
self.index = itertools.cycle(range(len(iterable)))
- def __next__(self) -> Any:
+ def __next__(self) -> T:
idx = next(self.index)
if idx == 0:
diff --git a/bot/utils/time.py b/bot/utils/time.py
index fbf2fd21..66f9e7cb 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -1,4 +1,4 @@
-import datetime
+from datetime import UTC, datetime
from dateutil.relativedelta import relativedelta
@@ -17,12 +17,11 @@ def _stringify_time_unit(value: int, unit: str) -> str:
"""
if unit == "seconds" and value == 0:
return "0 seconds"
- elif value == 1:
+ if value == 1:
return f"{value} {unit[:-1]}"
- elif value == 0:
+ if value == 0:
return f"less than a {unit[:-1]}"
- else:
- return f"{value} {unit}"
+ return f"{value} {unit}"
def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units: int = 6) -> str:
@@ -69,14 +68,14 @@ def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units:
return humanized
-def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max_units: int = 6) -> str:
+def time_since(past_datetime: datetime, precision: str = "seconds", max_units: int = 6) -> str:
"""
Takes a datetime and returns a human-readable string that describes how long ago that datetime was.
precision specifies the smallest unit of time to include (e.g. "seconds", "minutes").
max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
"""
- now = datetime.datetime.utcnow()
+ now = datetime.now(tz=UTC)
delta = abs(relativedelta(now, past_datetime))
humanized = humanize_delta(delta, precision, max_units)