diff options
Diffstat (limited to 'botcore')
-rw-r--r-- | botcore/_bot.py | 21 | ||||
-rw-r--r-- | botcore/site_api.py | 5 | ||||
-rw-r--r-- | botcore/utils/__init__.py | 16 | ||||
-rw-r--r-- | botcore/utils/_monkey_patches.py | 5 | ||||
-rw-r--r-- | botcore/utils/commands.py | 38 | ||||
-rw-r--r-- | botcore/utils/cooldown.py | 3 | ||||
-rw-r--r-- | botcore/utils/function.py | 3 | ||||
-rw-r--r-- | botcore/utils/interactions.py | 98 | ||||
-rw-r--r-- | botcore/utils/members.py | 9 | ||||
-rw-r--r-- | botcore/utils/regex.py | 7 | ||||
-rw-r--r-- | botcore/utils/scheduling.py | 32 |
11 files changed, 196 insertions, 41 deletions
diff --git a/botcore/_bot.py b/botcore/_bot.py index e9eba5c5..bb25c0b5 100644 --- a/botcore/_bot.py +++ b/botcore/_bot.py @@ -51,7 +51,7 @@ class BotBase(commands.Bot): Initialise the base bot instance. Args: - guild_id: The ID of the guild use for :func:`wait_until_guild_available`. + guild_id: The ID of the guild used for :func:`wait_until_guild_available`. allowed_roles: A list of role IDs that the bot is allowed to mention. http_session (aiohttp.ClientSession): The session to use for the bot. redis_session: The `async_rediscache.RedisSession`_ to use for the bot. @@ -197,7 +197,7 @@ class BotBase(commands.Bot): if not guild.roles or not guild.members or not guild.channels: msg = "Guild available event was dispatched but the cache appears to still be empty!" - self.log_to_dev_log(msg) + await self.log_to_dev_log(msg) return self._guild_available.set() @@ -234,14 +234,16 @@ class BotBase(commands.Bot): ) self.http.connector = self._connector - if getattr(self, "redis_session", False) and self.redis_session.closed: + if getattr(self, "redis_session", False) and not self.redis_session.valid: # If the RedisSession was somehow closed, we try to reconnect it # here. Normally, this shouldn't happen. - await self.redis_session.connect() + await self.redis_session.connect(ping=True) - # Create dummy stats client first, in case `statsd_url` is unreachable within `_connect_statsd()` + # Create dummy stats client first, in case `statsd_url` is unreachable or None self.stats = AsyncStatsClient(loop, "127.0.0.1") - self._connect_statsd(self.statsd_url, loop) + if self.statsd_url: + self._connect_statsd(self.statsd_url, loop) + await self.stats.create_socket() try: @@ -249,7 +251,7 @@ class BotBase(commands.Bot): except Exception as e: raise StartupError(e) - async def ping_services() -> None: + async def ping_services(self) -> None: """Ping all required services on setup to ensure they are up before starting.""" ... @@ -279,11 +281,8 @@ class BotBase(commands.Bot): if self._resolver: await self._resolver.close() - if self.stats._transport: + if getattr(self.stats, "_transport", False): self.stats._transport.close() - if getattr(self, "redis_session", False): - await self.redis_session.close() - if self._statsd_timerhandle: self._statsd_timerhandle.cancel() diff --git a/botcore/site_api.py b/botcore/site_api.py index dbdf4f3b..44309f9d 100644 --- a/botcore/site_api.py +++ b/botcore/site_api.py @@ -26,7 +26,7 @@ class ResponseCodeError(ValueError): Args: response (:obj:`aiohttp.ClientResponse`): The response object from the request. response_json: The JSON response returned from the request, if any. - request_text: The text of the request, if any. + response_text: The text of the request, if any. """ self.status = response.status self.response_json = response_json or {} @@ -76,7 +76,8 @@ class APIClient: """Close the aiohttp session.""" await self.session.close() - async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: + @staticmethod + async def maybe_raise_for_status(response: aiohttp.ClientResponse, should_raise: bool) -> None: """ Raise :exc:`ResponseCodeError` for non-OK response if an exception should be raised. diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index cfc5e99d..09aaa45f 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,18 @@ """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, cooldown, function, logging, members, regex, scheduling +from botcore.utils import ( + _monkey_patches, + caching, + channel, + commands, + cooldown, + function, + interactions, + logging, + members, + regex, + scheduling, +) from botcore.utils._extensions import unqualify @@ -24,8 +36,10 @@ __all__ = [ apply_monkey_patches, caching, channel, + commands, cooldown, function, + interactions, logging, members, regex, diff --git a/botcore/utils/_monkey_patches.py b/botcore/utils/_monkey_patches.py index f2c6c100..c2f8aa10 100644 --- a/botcore/utils/_monkey_patches.py +++ b/botcore/utils/_monkey_patches.py @@ -1,6 +1,7 @@ """Contains all common monkey patches, used to alter discord to fit our needs.""" import logging +import typing from datetime import datetime, timedelta from functools import partial, partialmethod @@ -46,9 +47,9 @@ def _patch_typing() -> None: log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. Stay safe!") original = http.HTTPClient.send_typing - last_403 = None + last_403: typing.Optional[datetime] = None - async def honeybadger_type(self, channel_id: int) -> None: # noqa: ANN001 + async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None: nonlocal last_403 if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5): log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") diff --git a/botcore/utils/commands.py b/botcore/utils/commands.py new file mode 100644 index 00000000..7afd8137 --- /dev/null +++ b/botcore/utils/commands.py @@ -0,0 +1,38 @@ +from typing import Optional + +from discord import Message +from discord.ext.commands import BadArgument, Context, clean_content + + +async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str: + """ + Cleans a text argument or replied message's content. + + Args: + ctx: The command's context + text: The provided text argument of the command (if given) + + Raises: + :exc:`discord.ext.commands.BadArgument` + `text` wasn't provided and there's no reply message / reply message content. + + Returns: + The cleaned version of `text`, if given, else replied message. + """ + clean_content_converter = clean_content(fix_channel_mentions=True) + + if text: + return await clean_content_converter.convert(ctx, text) + + if ( + (replied_message := getattr(ctx.message.reference, "resolved", None)) # message has a cached reference + and isinstance(replied_message, Message) # referenced message hasn't been deleted + ): + if not (content := ctx.message.reference.resolved.content): + # The referenced message doesn't have a content (e.g. embed/image), so raise error + raise BadArgument("The referenced message doesn't have a text content.") + + return await clean_content_converter.convert(ctx, content) + + # No text provided, and either no message was referenced or we can't access the content + raise BadArgument("Couldn't find text to clean. Provide a string or reply to a message to use its content.") diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py index b9149b48..ee65033d 100644 --- a/botcore/utils/cooldown.py +++ b/botcore/utils/cooldown.py @@ -7,10 +7,9 @@ import random import time import typing import weakref -from collections.abc import Awaitable, Hashable, Iterable +from collections.abc import Awaitable, Callable, Hashable, Iterable from contextlib import suppress from dataclasses import dataclass -from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable import discord from discord.ext.commands import CommandError, Context diff --git a/botcore/utils/function.py b/botcore/utils/function.py index e8d24e90..0e90d4c5 100644 --- a/botcore/utils/function.py +++ b/botcore/utils/function.py @@ -5,8 +5,7 @@ from __future__ import annotations import functools import types import typing -from collections.abc import Sequence, Set -from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable +from collections.abc import Callable, Sequence, Set __all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] diff --git a/botcore/utils/interactions.py b/botcore/utils/interactions.py new file mode 100644 index 00000000..26bd92f2 --- /dev/null +++ b/botcore/utils/interactions.py @@ -0,0 +1,98 @@ +import contextlib +from typing import Optional, Sequence + +from discord import ButtonStyle, Interaction, Message, NotFound, ui + +from botcore.utils.logging import get_logger + +log = get_logger(__name__) + + +class ViewWithUserAndRoleCheck(ui.View): + """ + A view that allows the original invoker and moderators to interact with it. + + Args: + allowed_users: A sequence of user's ids who are allowed to interact with the view. + allowed_roles: A sequence of role ids that are allowed to interact with the view. + timeout: Timeout in seconds from last interaction with the UI before no longer accepting input. + If ``None`` then there is no timeout. + message: The message to remove the view from on timeout. This can also be set with + ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated. + """ + + def __init__( + self, + *, + allowed_users: Sequence[int], + allowed_roles: Sequence[int], + timeout: Optional[float] = 180.0, + message: Optional[Message] = None + ) -> None: + super().__init__(timeout=timeout) + self.allowed_users = allowed_users + self.allowed_roles = allowed_roles + self.message = message + + async def interaction_check(self, interaction: Interaction) -> bool: + """ + Ensure the user clicking the button is the view invoker, or a moderator. + + Args: + interaction: The interaction that occurred. + """ + if interaction.user.id in self.allowed_users: + log.trace( + "Allowed interaction by %s (%d) on %d as they are an allowed user.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])): + log.trace( + "Allowed interaction by %s (%d)on %d as they have an allowed role.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + await interaction.response.send_message("This is not your button to click!", ephemeral=True) + return False + + async def on_timeout(self) -> None: + """Remove the view from ``self.message`` if set.""" + if self.message: + with contextlib.suppress(NotFound): + # Cover the case where this message has already been deleted by external means + await self.message.edit(view=None) + + +class DeleteMessageButton(ui.Button): + """ + A button that can be added to a view to delete the message containing the view on click. + + This button itself carries out no interaction checks, these should be done by the parent view. + + See :obj:`botcore.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks. + + Args: + style (:literal-url:`ButtonStyle <https://discordpy.readthedocs.io/en/latest/interactions/api.html#discord.ButtonStyle>`): + The style of the button, set to ``ButtonStyle.secondary`` if not specified. + label: The label of the button, set to "Delete" if not specified. + """ # noqa: E501 + + def __init__( + self, + *, + style: ButtonStyle = ButtonStyle.secondary, + label: str = "Delete", + **kwargs + ): + super().__init__(style=style, label=label, **kwargs) + + async def callback(self, interaction: Interaction) -> None: + """Delete the original message on button click.""" + await interaction.message.delete() diff --git a/botcore/utils/members.py b/botcore/utils/members.py index e89b4618..1536a8d1 100644 --- a/botcore/utils/members.py +++ b/botcore/utils/members.py @@ -1,6 +1,6 @@ """Useful helper functions for interactin with :obj:`discord.Member` objects.""" - import typing +from collections import abc import discord @@ -30,18 +30,19 @@ async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Op async def handle_role_change( member: discord.Member, - coro: typing.Callable[..., typing.Coroutine], + coro: typing.Callable[[discord.Role], abc.Coroutine], role: discord.Role ) -> None: """ - Await the given ``coro`` with ``member`` as the sole argument. + Await the given ``coro`` with ``role`` as the sole argument. Handle errors that we expect to be raised from :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`. Args: - member: The member to pass to ``coro``. + member: The member that is being modified for logging purposes. coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`. + role: The role to be passed to ``coro``. """ try: await coro(role) diff --git a/botcore/utils/regex.py b/botcore/utils/regex.py index 56c50dad..de82a1ed 100644 --- a/botcore/utils/regex.py +++ b/botcore/utils/regex.py @@ -3,6 +3,7 @@ import re DISCORD_INVITE = re.compile( + r"(https?://)?(www\.)?" # Optional http(s) and www. r"(discord([.,]|dot)gg|" # Could be discord.gg/ r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/ r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/ @@ -10,7 +11,7 @@ DISCORD_INVITE = re.compile( r"discord([.,]|dot)li|" # or discord.li r"discord([.,]|dot)io|" # or discord.io. r"((?<!\w)([.,]|dot))gg" # or .gg/ - r")([/]|slash)" # / or 'slash' + r")(/|slash)" # / or 'slash' r"(?P<invite>\S+)", # the invite code itself flags=re.IGNORECASE ) @@ -32,7 +33,7 @@ FORMATTED_CODE_REGEX = re.compile( r"(?P<code>.*?)" # extract all code inside the markup r"\s*" # any more whitespace before the end of the code markup r"(?P=delim)", # match the exact same delimiter from the start again - re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive + flags=re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive ) """ Regex for formatted code, using Discord's code blocks. @@ -44,7 +45,7 @@ RAW_CODE_REGEX = re.compile( r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code r"(?P<code>.*?)" # extract all the rest as code r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL # "." also matches newlines + flags=re.DOTALL # "." also matches newlines ) """ Regex for raw code, *not* using Discord's code blocks. diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py index 164f6b10..9517df6d 100644 --- a/botcore/utils/scheduling.py +++ b/botcore/utils/scheduling.py @@ -4,6 +4,7 @@ import asyncio import contextlib import inspect import typing +from collections import abc from datetime import datetime from functools import partial @@ -38,9 +39,9 @@ class Scheduler: self.name = name self._log = logging.get_logger(f"{__name__}.{name}") - self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {} + self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {} - def __contains__(self, task_id: typing.Hashable) -> bool: + def __contains__(self, task_id: abc.Hashable) -> bool: """ Return :obj:`True` if a task with the given ``task_id`` is currently scheduled. @@ -52,7 +53,7 @@ class Scheduler: """ return task_id in self._scheduled_tasks - def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: + def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: """ Schedule the execution of a ``coroutine``. @@ -79,7 +80,7 @@ class Scheduler: self._scheduled_tasks[task_id] = task self._log.debug(f"Scheduled task #{task_id} {id(task)}.") - def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: + def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: """ Schedule ``coroutine`` to be executed at the given ``time``. @@ -106,8 +107,8 @@ class Scheduler: def schedule_later( self, delay: typing.Union[int, float], - task_id: typing.Hashable, - coroutine: typing.Coroutine + task_id: abc.Hashable, + coroutine: abc.Coroutine ) -> None: """ Schedule ``coroutine`` to be executed after ``delay`` seconds. @@ -122,7 +123,7 @@ class Scheduler: """ self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - def cancel(self, task_id: typing.Hashable) -> None: + def cancel(self, task_id: abc.Hashable) -> None: """ Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist. @@ -150,8 +151,8 @@ class Scheduler: async def _await_later( self, delay: typing.Union[int, float], - task_id: typing.Hashable, - coroutine: typing.Coroutine + task_id: abc.Hashable, + coroutine: abc.Coroutine ) -> None: """Await ``coroutine`` after ``delay`` seconds.""" try: @@ -173,7 +174,7 @@ class Scheduler: else: self._log.debug(f"Finally block reached for #{task_id}; {state=}") - def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None: + def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None: """ Delete the task and raise its exception if one exists. @@ -208,13 +209,16 @@ class Scheduler: self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) +TASK_RETURN = typing.TypeVar("TASK_RETURN") + + def create_task( - coro: typing.Awaitable, + coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN], *, - suppressed_exceptions: tuple[typing.Type[Exception]] = (), + suppressed_exceptions: tuple[type[Exception], ...] = (), event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, **kwargs, -) -> asyncio.Task: +) -> asyncio.Task[TASK_RETURN]: """ Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task. @@ -238,7 +242,7 @@ def create_task( return task -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None: +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None: """Retrieve and log the exception raised in ``task`` if one exists.""" with contextlib.suppress(asyncio.CancelledError): exception = task.exception() |