diff options
| author | 2022-09-18 19:10:24 +0200 | |
|---|---|---|
| committer | 2022-09-18 19:14:08 +0200 | |
| commit | b6f033e7f5fcdb827e7fed29a4ed21108e54a414 (patch) | |
| tree | 99be74f8d90217e8d2dbeba442afce7ea04d5de6 /botcore/utils | |
| parent | ensure tuples from pos arg and kwarg tuples are differentiated (diff) | |
| parent | Merge pull request #138 from python-discord/bump-d.py (diff) | |
Merge remote-tracking branch 'upstream/main' into no-duplicate-deco
Diffstat (limited to 'botcore/utils')
| -rw-r--r-- | botcore/utils/__init__.py | 16 | ||||
| -rw-r--r-- | botcore/utils/_monkey_patches.py | 5 | ||||
| -rw-r--r-- | botcore/utils/commands.py | 38 | ||||
| -rw-r--r-- | botcore/utils/cooldown.py | 3 | ||||
| -rw-r--r-- | botcore/utils/function.py | 3 | ||||
| -rw-r--r-- | botcore/utils/interactions.py | 98 | ||||
| -rw-r--r-- | botcore/utils/members.py | 9 | ||||
| -rw-r--r-- | botcore/utils/regex.py | 7 | ||||
| -rw-r--r-- | botcore/utils/scheduling.py | 32 | 
9 files changed, 183 insertions, 28 deletions
| diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py index cfc5e99d..09aaa45f 100644 --- a/botcore/utils/__init__.py +++ b/botcore/utils/__init__.py @@ -1,6 +1,18 @@  """Useful utilities and tools for Discord bot development.""" -from botcore.utils import _monkey_patches, caching, channel, cooldown, function, logging, members, regex, scheduling +from botcore.utils import ( +    _monkey_patches, +    caching, +    channel, +    commands, +    cooldown, +    function, +    interactions, +    logging, +    members, +    regex, +    scheduling, +)  from botcore.utils._extensions import unqualify @@ -24,8 +36,10 @@ __all__ = [      apply_monkey_patches,      caching,      channel, +    commands,      cooldown,      function, +    interactions,      logging,      members,      regex, diff --git a/botcore/utils/_monkey_patches.py b/botcore/utils/_monkey_patches.py index f2c6c100..c2f8aa10 100644 --- a/botcore/utils/_monkey_patches.py +++ b/botcore/utils/_monkey_patches.py @@ -1,6 +1,7 @@  """Contains all common monkey patches, used to alter discord to fit our needs."""  import logging +import typing  from datetime import datetime, timedelta  from functools import partial, partialmethod @@ -46,9 +47,9 @@ def _patch_typing() -> None:      log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. Stay safe!")      original = http.HTTPClient.send_typing -    last_403 = None +    last_403: typing.Optional[datetime] = None -    async def honeybadger_type(self, channel_id: int) -> None:  # noqa: ANN001 +    async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None:          nonlocal last_403          if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5):              log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") diff --git a/botcore/utils/commands.py b/botcore/utils/commands.py new file mode 100644 index 00000000..7afd8137 --- /dev/null +++ b/botcore/utils/commands.py @@ -0,0 +1,38 @@ +from typing import Optional + +from discord import Message +from discord.ext.commands import BadArgument, Context, clean_content + + +async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str: +    """ +    Cleans a text argument or replied message's content. + +    Args: +        ctx: The command's context +        text: The provided text argument of the command (if given) + +    Raises: +        :exc:`discord.ext.commands.BadArgument` +            `text` wasn't provided and there's no reply message / reply message content. + +    Returns: +         The cleaned version of `text`, if given, else replied message. +    """ +    clean_content_converter = clean_content(fix_channel_mentions=True) + +    if text: +        return await clean_content_converter.convert(ctx, text) + +    if ( +        (replied_message := getattr(ctx.message.reference, "resolved", None))  # message has a cached reference +        and isinstance(replied_message, Message)  # referenced message hasn't been deleted +    ): +        if not (content := ctx.message.reference.resolved.content): +            # The referenced message doesn't have a content (e.g. embed/image), so raise error +            raise BadArgument("The referenced message doesn't have a text content.") + +        return await clean_content_converter.convert(ctx, content) + +    # No text provided, and either no message was referenced or we can't access the content +    raise BadArgument("Couldn't find text to clean. Provide a string or reply to a message to use its content.") diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py index b9149b48..ee65033d 100644 --- a/botcore/utils/cooldown.py +++ b/botcore/utils/cooldown.py @@ -7,10 +7,9 @@ import random  import time  import typing  import weakref -from collections.abc import Awaitable, Hashable, Iterable +from collections.abc import Awaitable, Callable, Hashable, Iterable  from contextlib import suppress  from dataclasses import dataclass -from typing import Callable  # sphinx-autodoc-typehints breaks with collections.abc.Callable  import discord  from discord.ext.commands import CommandError, Context diff --git a/botcore/utils/function.py b/botcore/utils/function.py index e8d24e90..0e90d4c5 100644 --- a/botcore/utils/function.py +++ b/botcore/utils/function.py @@ -5,8 +5,7 @@ from __future__ import annotations  import functools  import types  import typing -from collections.abc import Sequence, Set -from typing import Callable  # sphinx-autodoc-typehints breaks with collections.abc.Callable +from collections.abc import Callable, Sequence, Set  __all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] diff --git a/botcore/utils/interactions.py b/botcore/utils/interactions.py new file mode 100644 index 00000000..26bd92f2 --- /dev/null +++ b/botcore/utils/interactions.py @@ -0,0 +1,98 @@ +import contextlib +from typing import Optional, Sequence + +from discord import ButtonStyle, Interaction, Message, NotFound, ui + +from botcore.utils.logging import get_logger + +log = get_logger(__name__) + + +class ViewWithUserAndRoleCheck(ui.View): +    """ +    A view that allows the original invoker and moderators to interact with it. + +    Args: +        allowed_users: A sequence of user's ids who are allowed to interact with the view. +        allowed_roles: A sequence of role ids that are allowed to interact with the view. +        timeout: Timeout in seconds from last interaction with the UI before no longer accepting input. +            If ``None`` then there is no timeout. +        message: The message to remove the view from on timeout. This can also be set with +            ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated. +    """ + +    def __init__( +        self, +        *, +        allowed_users: Sequence[int], +        allowed_roles: Sequence[int], +        timeout: Optional[float] = 180.0, +        message: Optional[Message] = None +    ) -> None: +        super().__init__(timeout=timeout) +        self.allowed_users = allowed_users +        self.allowed_roles = allowed_roles +        self.message = message + +    async def interaction_check(self, interaction: Interaction) -> bool: +        """ +        Ensure the user clicking the button is the view invoker, or a moderator. + +        Args: +            interaction: The interaction that occurred. +        """ +        if interaction.user.id in self.allowed_users: +            log.trace( +                "Allowed interaction by %s (%d) on %d as they are an allowed user.", +                interaction.user, +                interaction.user.id, +                interaction.message.id, +            ) +            return True + +        if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])): +            log.trace( +                "Allowed interaction by %s (%d)on %d as they have an allowed role.", +                interaction.user, +                interaction.user.id, +                interaction.message.id, +            ) +            return True + +        await interaction.response.send_message("This is not your button to click!", ephemeral=True) +        return False + +    async def on_timeout(self) -> None: +        """Remove the view from ``self.message`` if set.""" +        if self.message: +            with contextlib.suppress(NotFound): +                # Cover the case where this message has already been deleted by external means +                await self.message.edit(view=None) + + +class DeleteMessageButton(ui.Button): +    """ +    A button that can be added to a view to delete the message containing the view on click. + +    This button itself carries out no interaction checks, these should be done by the parent view. + +    See :obj:`botcore.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks. + +    Args: +        style (:literal-url:`ButtonStyle <https://discordpy.readthedocs.io/en/latest/interactions/api.html#discord.ButtonStyle>`): +            The style of the button, set to ``ButtonStyle.secondary`` if not specified. +        label: The label of the button, set to "Delete" if not specified. +    """  # noqa: E501 + +    def __init__( +        self, +        *, +        style: ButtonStyle = ButtonStyle.secondary, +        label: str = "Delete", +        **kwargs +    ): +        super().__init__(style=style, label=label, **kwargs) + +    async def callback(self, interaction: Interaction) -> None: +        """Delete the original message on button click.""" +        await interaction.message.delete() diff --git a/botcore/utils/members.py b/botcore/utils/members.py index e89b4618..1536a8d1 100644 --- a/botcore/utils/members.py +++ b/botcore/utils/members.py @@ -1,6 +1,6 @@  """Useful helper functions for interactin with :obj:`discord.Member` objects.""" -  import typing +from collections import abc  import discord @@ -30,18 +30,19 @@ async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Op  async def handle_role_change(      member: discord.Member, -    coro: typing.Callable[..., typing.Coroutine], +    coro: typing.Callable[[discord.Role], abc.Coroutine],      role: discord.Role  ) -> None:      """ -    Await the given ``coro`` with ``member`` as the sole argument. +    Await the given ``coro`` with ``role`` as the sole argument.      Handle errors that we expect to be raised from      :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`.      Args: -        member: The member to pass to ``coro``. +        member: The member that is being modified for logging purposes.          coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`. +        role: The role to be passed to ``coro``.      """      try:          await coro(role) diff --git a/botcore/utils/regex.py b/botcore/utils/regex.py index 56c50dad..de82a1ed 100644 --- a/botcore/utils/regex.py +++ b/botcore/utils/regex.py @@ -3,6 +3,7 @@  import re  DISCORD_INVITE = re.compile( +    r"(https?://)?(www\.)?"                      # Optional http(s) and www.      r"(discord([.,]|dot)gg|"                     # Could be discord.gg/      r"discord([.,]|dot)com(/|slash)invite|"      # or discord.com/invite/      r"discordapp([.,]|dot)com(/|slash)invite|"   # or discordapp.com/invite/ @@ -10,7 +11,7 @@ DISCORD_INVITE = re.compile(      r"discord([.,]|dot)li|"                      # or discord.li      r"discord([.,]|dot)io|"                      # or discord.io.      r"((?<!\w)([.,]|dot))gg"                     # or .gg/ -    r")([/]|slash)"                              # / or 'slash' +    r")(/|slash)"                                # / or 'slash'      r"(?P<invite>\S+)",                          # the invite code itself      flags=re.IGNORECASE  ) @@ -32,7 +33,7 @@ FORMATTED_CODE_REGEX = re.compile(      r"(?P<code>.*?)"                        # extract all code inside the markup      r"\s*"                                  # any more whitespace before the end of the code markup      r"(?P=delim)",                          # match the exact same delimiter from the start again -    re.DOTALL | re.IGNORECASE               # "." also matches newlines, case insensitive +    flags=re.DOTALL | re.IGNORECASE         # "." also matches newlines, case insensitive  )  """  Regex for formatted code, using Discord's code blocks. @@ -44,7 +45,7 @@ RAW_CODE_REGEX = re.compile(      r"^(?:[ \t]*\n)*"                       # any blank (empty or tabs/spaces only) lines before the code      r"(?P<code>.*?)"                        # extract all the rest as code      r"\s*$",                                # any trailing whitespace until the end of the string -    re.DOTALL                               # "." also matches newlines +    flags=re.DOTALL                         # "." also matches newlines  )  """  Regex for raw code, *not* using Discord's code blocks. diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py index 164f6b10..9517df6d 100644 --- a/botcore/utils/scheduling.py +++ b/botcore/utils/scheduling.py @@ -4,6 +4,7 @@ import asyncio  import contextlib  import inspect  import typing +from collections import abc  from datetime import datetime  from functools import partial @@ -38,9 +39,9 @@ class Scheduler:          self.name = name          self._log = logging.get_logger(f"{__name__}.{name}") -        self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {} +        self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {} -    def __contains__(self, task_id: typing.Hashable) -> bool: +    def __contains__(self, task_id: abc.Hashable) -> bool:          """          Return :obj:`True` if a task with the given ``task_id`` is currently scheduled. @@ -52,7 +53,7 @@ class Scheduler:          """          return task_id in self._scheduled_tasks -    def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: +    def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:          """          Schedule the execution of a ``coroutine``. @@ -79,7 +80,7 @@ class Scheduler:          self._scheduled_tasks[task_id] = task          self._log.debug(f"Scheduled task #{task_id} {id(task)}.") -    def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: +    def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:          """          Schedule ``coroutine`` to be executed at the given ``time``. @@ -106,8 +107,8 @@ class Scheduler:      def schedule_later(          self,          delay: typing.Union[int, float], -        task_id: typing.Hashable, -        coroutine: typing.Coroutine +        task_id: abc.Hashable, +        coroutine: abc.Coroutine      ) -> None:          """          Schedule ``coroutine`` to be executed after ``delay`` seconds. @@ -122,7 +123,7 @@ class Scheduler:          """          self.schedule(task_id, self._await_later(delay, task_id, coroutine)) -    def cancel(self, task_id: typing.Hashable) -> None: +    def cancel(self, task_id: abc.Hashable) -> None:          """          Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist. @@ -150,8 +151,8 @@ class Scheduler:      async def _await_later(          self,          delay: typing.Union[int, float], -        task_id: typing.Hashable, -        coroutine: typing.Coroutine +        task_id: abc.Hashable, +        coroutine: abc.Coroutine      ) -> None:          """Await ``coroutine`` after ``delay`` seconds."""          try: @@ -173,7 +174,7 @@ class Scheduler:              else:                  self._log.debug(f"Finally block reached for #{task_id}; {state=}") -    def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None: +    def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None:          """          Delete the task and raise its exception if one exists. @@ -208,13 +209,16 @@ class Scheduler:                  self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) +TASK_RETURN = typing.TypeVar("TASK_RETURN") + +  def create_task( -    coro: typing.Awaitable, +    coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN],      *, -    suppressed_exceptions: tuple[typing.Type[Exception]] = (), +    suppressed_exceptions: tuple[type[Exception], ...] = (),      event_loop: typing.Optional[asyncio.AbstractEventLoop] = None,      **kwargs, -) -> asyncio.Task: +) -> asyncio.Task[TASK_RETURN]:      """      Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task. @@ -238,7 +242,7 @@ def create_task(      return task -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None: +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None:      """Retrieve and log the exception raised in ``task`` if one exists."""      with contextlib.suppress(asyncio.CancelledError):          exception = task.exception() | 
