aboutsummaryrefslogtreecommitdiffstats
path: root/botcore/utils
diff options
context:
space:
mode:
authorGravatar Numerlor <[email protected]>2022-09-18 19:10:24 +0200
committerGravatar Numerlor <[email protected]>2022-09-18 19:14:08 +0200
commitb6f033e7f5fcdb827e7fed29a4ed21108e54a414 (patch)
tree99be74f8d90217e8d2dbeba442afce7ea04d5de6 /botcore/utils
parentensure tuples from pos arg and kwarg tuples are differentiated (diff)
parentMerge pull request #138 from python-discord/bump-d.py (diff)
Merge remote-tracking branch 'upstream/main' into no-duplicate-deco
Diffstat (limited to 'botcore/utils')
-rw-r--r--botcore/utils/__init__.py16
-rw-r--r--botcore/utils/_monkey_patches.py5
-rw-r--r--botcore/utils/commands.py38
-rw-r--r--botcore/utils/cooldown.py3
-rw-r--r--botcore/utils/function.py3
-rw-r--r--botcore/utils/interactions.py98
-rw-r--r--botcore/utils/members.py9
-rw-r--r--botcore/utils/regex.py7
-rw-r--r--botcore/utils/scheduling.py32
9 files changed, 183 insertions, 28 deletions
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py
index cfc5e99d..09aaa45f 100644
--- a/botcore/utils/__init__.py
+++ b/botcore/utils/__init__.py
@@ -1,6 +1,18 @@
"""Useful utilities and tools for Discord bot development."""
-from botcore.utils import _monkey_patches, caching, channel, cooldown, function, logging, members, regex, scheduling
+from botcore.utils import (
+ _monkey_patches,
+ caching,
+ channel,
+ commands,
+ cooldown,
+ function,
+ interactions,
+ logging,
+ members,
+ regex,
+ scheduling,
+)
from botcore.utils._extensions import unqualify
@@ -24,8 +36,10 @@ __all__ = [
apply_monkey_patches,
caching,
channel,
+ commands,
cooldown,
function,
+ interactions,
logging,
members,
regex,
diff --git a/botcore/utils/_monkey_patches.py b/botcore/utils/_monkey_patches.py
index f2c6c100..c2f8aa10 100644
--- a/botcore/utils/_monkey_patches.py
+++ b/botcore/utils/_monkey_patches.py
@@ -1,6 +1,7 @@
"""Contains all common monkey patches, used to alter discord to fit our needs."""
import logging
+import typing
from datetime import datetime, timedelta
from functools import partial, partialmethod
@@ -46,9 +47,9 @@ def _patch_typing() -> None:
log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. Stay safe!")
original = http.HTTPClient.send_typing
- last_403 = None
+ last_403: typing.Optional[datetime] = None
- async def honeybadger_type(self, channel_id: int) -> None: # noqa: ANN001
+ async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None:
nonlocal last_403
if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5):
log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.")
diff --git a/botcore/utils/commands.py b/botcore/utils/commands.py
new file mode 100644
index 00000000..7afd8137
--- /dev/null
+++ b/botcore/utils/commands.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from discord import Message
+from discord.ext.commands import BadArgument, Context, clean_content
+
+
+async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str:
+ """
+ Cleans a text argument or replied message's content.
+
+ Args:
+ ctx: The command's context
+ text: The provided text argument of the command (if given)
+
+ Raises:
+ :exc:`discord.ext.commands.BadArgument`
+ `text` wasn't provided and there's no reply message / reply message content.
+
+ Returns:
+ The cleaned version of `text`, if given, else replied message.
+ """
+ clean_content_converter = clean_content(fix_channel_mentions=True)
+
+ if text:
+ return await clean_content_converter.convert(ctx, text)
+
+ if (
+ (replied_message := getattr(ctx.message.reference, "resolved", None)) # message has a cached reference
+ and isinstance(replied_message, Message) # referenced message hasn't been deleted
+ ):
+ if not (content := ctx.message.reference.resolved.content):
+ # The referenced message doesn't have a content (e.g. embed/image), so raise error
+ raise BadArgument("The referenced message doesn't have a text content.")
+
+ return await clean_content_converter.convert(ctx, content)
+
+ # No text provided, and either no message was referenced or we can't access the content
+ raise BadArgument("Couldn't find text to clean. Provide a string or reply to a message to use its content.")
diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py
index b9149b48..ee65033d 100644
--- a/botcore/utils/cooldown.py
+++ b/botcore/utils/cooldown.py
@@ -7,10 +7,9 @@ import random
import time
import typing
import weakref
-from collections.abc import Awaitable, Hashable, Iterable
+from collections.abc import Awaitable, Callable, Hashable, Iterable
from contextlib import suppress
from dataclasses import dataclass
-from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable
import discord
from discord.ext.commands import CommandError, Context
diff --git a/botcore/utils/function.py b/botcore/utils/function.py
index e8d24e90..0e90d4c5 100644
--- a/botcore/utils/function.py
+++ b/botcore/utils/function.py
@@ -5,8 +5,7 @@ from __future__ import annotations
import functools
import types
import typing
-from collections.abc import Sequence, Set
-from typing import Callable # sphinx-autodoc-typehints breaks with collections.abc.Callable
+from collections.abc import Callable, Sequence, Set
__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"]
diff --git a/botcore/utils/interactions.py b/botcore/utils/interactions.py
new file mode 100644
index 00000000..26bd92f2
--- /dev/null
+++ b/botcore/utils/interactions.py
@@ -0,0 +1,98 @@
+import contextlib
+from typing import Optional, Sequence
+
+from discord import ButtonStyle, Interaction, Message, NotFound, ui
+
+from botcore.utils.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class ViewWithUserAndRoleCheck(ui.View):
+ """
+ A view that allows the original invoker and moderators to interact with it.
+
+ Args:
+ allowed_users: A sequence of user's ids who are allowed to interact with the view.
+ allowed_roles: A sequence of role ids that are allowed to interact with the view.
+ timeout: Timeout in seconds from last interaction with the UI before no longer accepting input.
+ If ``None`` then there is no timeout.
+ message: The message to remove the view from on timeout. This can also be set with
+ ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated.
+ """
+
+ def __init__(
+ self,
+ *,
+ allowed_users: Sequence[int],
+ allowed_roles: Sequence[int],
+ timeout: Optional[float] = 180.0,
+ message: Optional[Message] = None
+ ) -> None:
+ super().__init__(timeout=timeout)
+ self.allowed_users = allowed_users
+ self.allowed_roles = allowed_roles
+ self.message = message
+
+ async def interaction_check(self, interaction: Interaction) -> bool:
+ """
+ Ensure the user clicking the button is the view invoker, or a moderator.
+
+ Args:
+ interaction: The interaction that occurred.
+ """
+ if interaction.user.id in self.allowed_users:
+ log.trace(
+ "Allowed interaction by %s (%d) on %d as they are an allowed user.",
+ interaction.user,
+ interaction.user.id,
+ interaction.message.id,
+ )
+ return True
+
+ if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])):
+ log.trace(
+ "Allowed interaction by %s (%d)on %d as they have an allowed role.",
+ interaction.user,
+ interaction.user.id,
+ interaction.message.id,
+ )
+ return True
+
+ await interaction.response.send_message("This is not your button to click!", ephemeral=True)
+ return False
+
+ async def on_timeout(self) -> None:
+ """Remove the view from ``self.message`` if set."""
+ if self.message:
+ with contextlib.suppress(NotFound):
+ # Cover the case where this message has already been deleted by external means
+ await self.message.edit(view=None)
+
+
+class DeleteMessageButton(ui.Button):
+ """
+ A button that can be added to a view to delete the message containing the view on click.
+
+ This button itself carries out no interaction checks, these should be done by the parent view.
+
+ See :obj:`botcore.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks.
+
+ Args:
+ style (:literal-url:`ButtonStyle <https://discordpy.readthedocs.io/en/latest/interactions/api.html#discord.ButtonStyle>`):
+ The style of the button, set to ``ButtonStyle.secondary`` if not specified.
+ label: The label of the button, set to "Delete" if not specified.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ *,
+ style: ButtonStyle = ButtonStyle.secondary,
+ label: str = "Delete",
+ **kwargs
+ ):
+ super().__init__(style=style, label=label, **kwargs)
+
+ async def callback(self, interaction: Interaction) -> None:
+ """Delete the original message on button click."""
+ await interaction.message.delete()
diff --git a/botcore/utils/members.py b/botcore/utils/members.py
index e89b4618..1536a8d1 100644
--- a/botcore/utils/members.py
+++ b/botcore/utils/members.py
@@ -1,6 +1,6 @@
"""Useful helper functions for interactin with :obj:`discord.Member` objects."""
-
import typing
+from collections import abc
import discord
@@ -30,18 +30,19 @@ async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Op
async def handle_role_change(
member: discord.Member,
- coro: typing.Callable[..., typing.Coroutine],
+ coro: typing.Callable[[discord.Role], abc.Coroutine],
role: discord.Role
) -> None:
"""
- Await the given ``coro`` with ``member`` as the sole argument.
+ Await the given ``coro`` with ``role`` as the sole argument.
Handle errors that we expect to be raised from
:obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`.
Args:
- member: The member to pass to ``coro``.
+ member: The member that is being modified for logging purposes.
coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`.
+ role: The role to be passed to ``coro``.
"""
try:
await coro(role)
diff --git a/botcore/utils/regex.py b/botcore/utils/regex.py
index 56c50dad..de82a1ed 100644
--- a/botcore/utils/regex.py
+++ b/botcore/utils/regex.py
@@ -3,6 +3,7 @@
import re
DISCORD_INVITE = re.compile(
+ r"(https?://)?(www\.)?" # Optional http(s) and www.
r"(discord([.,]|dot)gg|" # Could be discord.gg/
r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/
r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/
@@ -10,7 +11,7 @@ DISCORD_INVITE = re.compile(
r"discord([.,]|dot)li|" # or discord.li
r"discord([.,]|dot)io|" # or discord.io.
r"((?<!\w)([.,]|dot))gg" # or .gg/
- r")([/]|slash)" # / or 'slash'
+ r")(/|slash)" # / or 'slash'
r"(?P<invite>\S+)", # the invite code itself
flags=re.IGNORECASE
)
@@ -32,7 +33,7 @@ FORMATTED_CODE_REGEX = re.compile(
r"(?P<code>.*?)" # extract all code inside the markup
r"\s*" # any more whitespace before the end of the code markup
r"(?P=delim)", # match the exact same delimiter from the start again
- re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive
+ flags=re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive
)
"""
Regex for formatted code, using Discord's code blocks.
@@ -44,7 +45,7 @@ RAW_CODE_REGEX = re.compile(
r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
r"(?P<code>.*?)" # extract all the rest as code
r"\s*$", # any trailing whitespace until the end of the string
- re.DOTALL # "." also matches newlines
+ flags=re.DOTALL # "." also matches newlines
)
"""
Regex for raw code, *not* using Discord's code blocks.
diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py
index 164f6b10..9517df6d 100644
--- a/botcore/utils/scheduling.py
+++ b/botcore/utils/scheduling.py
@@ -4,6 +4,7 @@ import asyncio
import contextlib
import inspect
import typing
+from collections import abc
from datetime import datetime
from functools import partial
@@ -38,9 +39,9 @@ class Scheduler:
self.name = name
self._log = logging.get_logger(f"{__name__}.{name}")
- self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {}
+ self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {}
- def __contains__(self, task_id: typing.Hashable) -> bool:
+ def __contains__(self, task_id: abc.Hashable) -> bool:
"""
Return :obj:`True` if a task with the given ``task_id`` is currently scheduled.
@@ -52,7 +53,7 @@ class Scheduler:
"""
return task_id in self._scheduled_tasks
- def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:
"""
Schedule the execution of a ``coroutine``.
@@ -79,7 +80,7 @@ class Scheduler:
self._scheduled_tasks[task_id] = task
self._log.debug(f"Scheduled task #{task_id} {id(task)}.")
- def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:
"""
Schedule ``coroutine`` to be executed at the given ``time``.
@@ -106,8 +107,8 @@ class Scheduler:
def schedule_later(
self,
delay: typing.Union[int, float],
- task_id: typing.Hashable,
- coroutine: typing.Coroutine
+ task_id: abc.Hashable,
+ coroutine: abc.Coroutine
) -> None:
"""
Schedule ``coroutine`` to be executed after ``delay`` seconds.
@@ -122,7 +123,7 @@ class Scheduler:
"""
self.schedule(task_id, self._await_later(delay, task_id, coroutine))
- def cancel(self, task_id: typing.Hashable) -> None:
+ def cancel(self, task_id: abc.Hashable) -> None:
"""
Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist.
@@ -150,8 +151,8 @@ class Scheduler:
async def _await_later(
self,
delay: typing.Union[int, float],
- task_id: typing.Hashable,
- coroutine: typing.Coroutine
+ task_id: abc.Hashable,
+ coroutine: abc.Coroutine
) -> None:
"""Await ``coroutine`` after ``delay`` seconds."""
try:
@@ -173,7 +174,7 @@ class Scheduler:
else:
self._log.debug(f"Finally block reached for #{task_id}; {state=}")
- def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None:
+ def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None:
"""
Delete the task and raise its exception if one exists.
@@ -208,13 +209,16 @@ class Scheduler:
self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
+TASK_RETURN = typing.TypeVar("TASK_RETURN")
+
+
def create_task(
- coro: typing.Awaitable,
+ coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN],
*,
- suppressed_exceptions: tuple[typing.Type[Exception]] = (),
+ suppressed_exceptions: tuple[type[Exception], ...] = (),
event_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
**kwargs,
-) -> asyncio.Task:
+) -> asyncio.Task[TASK_RETURN]:
"""
Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task.
@@ -238,7 +242,7 @@ def create_task(
return task
-def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None:
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None:
"""Retrieve and log the exception raised in ``task`` if one exists."""
with contextlib.suppress(asyncio.CancelledError):
exception = task.exception()