diff options
author | 2022-11-05 14:15:11 +0000 | |
---|---|---|
committer | 2022-11-05 14:15:11 +0000 | |
commit | 3f55e7149a3197b7fa41fcf7dc7df47a3a209cfd (patch) | |
tree | 11d77db3a48d8226bc1cd14cb18a21d2423b6352 /pydis_core/utils | |
parent | Use New Static Build Site API (#122) (diff) | |
parent | Add six as a dev dep (diff) |
Merge pull request #157 from python-discord/prepare-for-pypi-releasev9.0.0
Prepare for pypi release
Diffstat (limited to 'pydis_core/utils')
-rw-r--r-- | pydis_core/utils/__init__.py | 50 | ||||
-rw-r--r-- | pydis_core/utils/_extensions.py | 57 | ||||
-rw-r--r-- | pydis_core/utils/_monkey_patches.py | 73 | ||||
-rw-r--r-- | pydis_core/utils/caching.py | 65 | ||||
-rw-r--r-- | pydis_core/utils/channel.py | 54 | ||||
-rw-r--r-- | pydis_core/utils/commands.py | 38 | ||||
-rw-r--r-- | pydis_core/utils/cooldown.py | 220 | ||||
-rw-r--r-- | pydis_core/utils/function.py | 111 | ||||
-rw-r--r-- | pydis_core/utils/interactions.py | 98 | ||||
-rw-r--r-- | pydis_core/utils/logging.py | 51 | ||||
-rw-r--r-- | pydis_core/utils/members.py | 57 | ||||
-rw-r--r-- | pydis_core/utils/regex.py | 54 | ||||
-rw-r--r-- | pydis_core/utils/scheduling.py | 252 |
13 files changed, 1180 insertions, 0 deletions
diff --git a/pydis_core/utils/__init__.py b/pydis_core/utils/__init__.py new file mode 100644 index 00000000..0542231e --- /dev/null +++ b/pydis_core/utils/__init__.py @@ -0,0 +1,50 @@ +"""Useful utilities and tools for Discord bot development.""" + +from pydis_core.utils import ( + _monkey_patches, + caching, + channel, + commands, + cooldown, + function, + interactions, + logging, + members, + regex, + scheduling, +) +from pydis_core.utils._extensions import unqualify + + +def apply_monkey_patches() -> None: + """ + Applies all common monkey patches for our bots. + + Patches :obj:`discord.ext.commands.Command` and :obj:`discord.ext.commands.Group` to support root aliases. + A ``root_aliases`` keyword argument is added to these two objects, which is a sequence of alias names + that will act as top-level groups rather than being aliases of the command's group. + + It's stored as an attribute also named ``root_aliases`` + + Patches discord's internal ``send_typing`` method so that it ignores 403 errors from Discord. + When under heavy load Discord has added a CloudFlare worker to this route, which causes 403 errors to be thrown. + """ + _monkey_patches._apply_monkey_patches() + + +__all__ = [ + apply_monkey_patches, + caching, + channel, + commands, + cooldown, + function, + interactions, + logging, + members, + regex, + scheduling, + unqualify, +] + +__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/pydis_core/utils/_extensions.py b/pydis_core/utils/_extensions.py new file mode 100644 index 00000000..536a0715 --- /dev/null +++ b/pydis_core/utils/_extensions.py @@ -0,0 +1,57 @@ +"""Utilities for loading Discord extensions.""" + +import importlib +import inspect +import pkgutil +import types +from typing import NoReturn + + +def unqualify(name: str) -> str: + """ + Return an unqualified name given a qualified module/package ``name``. + + Args: + name: The module name to unqualify. + + Returns: + The unqualified module name. + """ + return name.rsplit(".", maxsplit=1)[-1] + + +def ignore_module(module: pkgutil.ModuleInfo) -> bool: + """Return whether the module with name `name` should be ignored.""" + return any(name.startswith("_") for name in module.name.split(".")) + + +def walk_extensions(module: types.ModuleType) -> frozenset[str]: + """ + Return all extension names from the given module. + + Args: + module (types.ModuleType): The module to look for extensions in. + + Returns: + A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. + """ + + def on_error(name: str) -> NoReturn: + raise ImportError(name=name) # pragma: no cover + + modules = set() + + for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): + if ignore_module(module_info): + # Ignore modules/packages that have a name starting with an underscore anywhere in their trees. + continue + + if module_info.ispkg: + imported = importlib.import_module(module_info.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. + continue + + modules.add(module_info.name) + + return frozenset(modules) diff --git a/pydis_core/utils/_monkey_patches.py b/pydis_core/utils/_monkey_patches.py new file mode 100644 index 00000000..f0a8dc9c --- /dev/null +++ b/pydis_core/utils/_monkey_patches.py @@ -0,0 +1,73 @@ +"""Contains all common monkey patches, used to alter discord to fit our needs.""" + +import logging +import typing +from datetime import datetime, timedelta +from functools import partial, partialmethod + +from discord import Forbidden, http +from discord.ext import commands + +log = logging.getLogger(__name__) + + +class _Command(commands.Command): + """ + A :obj:`discord.ext.commands.Command` subclass which supports root aliases. + + A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as + top-level commands rather than being aliases of the command's group. It's stored as an attribute + also named ``root_aliases``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.root_aliases = kwargs.get("root_aliases", []) + + if not isinstance(self.root_aliases, (list, tuple)): + raise TypeError("Root aliases of a command must be a list or a tuple of strings.") + + +class _Group(commands.Group, _Command): + """ + A :obj:`discord.ext.commands.Group` subclass which supports root aliases. + + A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as + top-level groups rather than being aliases of the command's group. It's stored as an attribute + also named ``root_aliases``. + """ + + +def _patch_typing() -> None: + """ + Sometimes Discord turns off typing events by throwing 403s. + + Handle those issues by patching discord's internal ``send_typing`` method so it ignores 403s in general. + """ + log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. Stay safe!") + + original = http.HTTPClient.send_typing + last_403: typing.Optional[datetime] = None + + async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None: + nonlocal last_403 + if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5): + log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") + return + try: + await original(self, channel_id) + except Forbidden: + last_403 = datetime.utcnow() + log.warning("Got a 403 from typing event!") + + http.HTTPClient.send_typing = honeybadger_type + + +def _apply_monkey_patches() -> None: + """This is surfaced directly in pydis_core.utils.apply_monkey_patches().""" + commands.command = partial(commands.command, cls=_Command) + commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=_Command) + + commands.group = partial(commands.group, cls=_Group) + commands.GroupMixin.group = partialmethod(commands.GroupMixin.group, cls=_Group) + _patch_typing() diff --git a/pydis_core/utils/caching.py b/pydis_core/utils/caching.py new file mode 100644 index 00000000..ac34bb9b --- /dev/null +++ b/pydis_core/utils/caching.py @@ -0,0 +1,65 @@ +"""Utilities related to custom caches.""" + +import functools +import typing +from collections import OrderedDict + + +class AsyncCache: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + + def __init__(self, max_size: int = 128): + """ + Initialise a new :obj:`AsyncCache` instance. + + Args: + max_size: How many items to store in the cache. + """ + self._cache = OrderedDict() + self._max_size = max_size + + def __call__(self, arg_offset: int = 0) -> typing.Callable: + """ + Decorator for async cache. + + Args: + arg_offset: The offset for the position of the key argument. + + Returns: + A decorator to wrap the target function. + """ + + def decorator(function: typing.Callable) -> typing.Callable: + """ + Define the async cache decorator. + + Args: + function: The function to wrap. + + Returns: + The wrapped function. + """ + + @functools.wraps(function) + async def wrapper(*args) -> typing.Any: + """Decorator wrapper for the caching logic.""" + key = args[arg_offset:] + + if key not in self._cache: + if len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + self._cache[key] = await function(*args) + return self._cache[key] + return wrapper + return decorator + + def clear(self) -> None: + """Clear cache instance.""" + self._cache.clear() diff --git a/pydis_core/utils/channel.py b/pydis_core/utils/channel.py new file mode 100644 index 00000000..854c64fd --- /dev/null +++ b/pydis_core/utils/channel.py @@ -0,0 +1,54 @@ +"""Useful helper functions for interacting with various discord channel objects.""" + +import discord +from discord.ext.commands import Bot + +from pydis_core.utils import logging + +log = logging.get_logger(__name__) + + +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: + """ + Return whether the given ``channel`` in the the category with the id ``category_id``. + + Args: + channel: The channel to check. + category_id: The category to check for. + + Returns: + A bool depending on whether the channel is in the category. + """ + return getattr(channel, "category_id", None) == category_id + + +async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel: + """ + Attempt to get or fetch the given ``channel_id`` from the bots cache, and return it. + + Args: + bot: The :obj:`discord.ext.commands.Bot` instance to use for getting/fetching. + channel_id: The channel to get/fetch. + + Raises: + :exc:`discord.InvalidData` + An unknown channel type was received from Discord. + :exc:`discord.HTTPException` + Retrieving the channel failed. + :exc:`discord.NotFound` + Invalid Channel ID. + :exc:`discord.Forbidden` + You do not have permission to fetch this channel. + + Returns: + The channel from the ID. + """ + log.trace(f"Getting the channel {channel_id}.") + + channel = bot.get_channel(channel_id) + if not channel: + log.debug(f"Channel {channel_id} is not in cache; fetching from API.") + channel = await bot.fetch_channel(channel_id) + + log.trace(f"Channel #{channel} ({channel_id}) retrieved.") + return channel diff --git a/pydis_core/utils/commands.py b/pydis_core/utils/commands.py new file mode 100644 index 00000000..7afd8137 --- /dev/null +++ b/pydis_core/utils/commands.py @@ -0,0 +1,38 @@ +from typing import Optional + +from discord import Message +from discord.ext.commands import BadArgument, Context, clean_content + + +async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str: + """ + Cleans a text argument or replied message's content. + + Args: + ctx: The command's context + text: The provided text argument of the command (if given) + + Raises: + :exc:`discord.ext.commands.BadArgument` + `text` wasn't provided and there's no reply message / reply message content. + + Returns: + The cleaned version of `text`, if given, else replied message. + """ + clean_content_converter = clean_content(fix_channel_mentions=True) + + if text: + return await clean_content_converter.convert(ctx, text) + + if ( + (replied_message := getattr(ctx.message.reference, "resolved", None)) # message has a cached reference + and isinstance(replied_message, Message) # referenced message hasn't been deleted + ): + if not (content := ctx.message.reference.resolved.content): + # The referenced message doesn't have a content (e.g. embed/image), so raise error + raise BadArgument("The referenced message doesn't have a text content.") + + return await clean_content_converter.convert(ctx, content) + + # No text provided, and either no message was referenced or we can't access the content + raise BadArgument("Couldn't find text to clean. Provide a string or reply to a message to use its content.") diff --git a/pydis_core/utils/cooldown.py b/pydis_core/utils/cooldown.py new file mode 100644 index 00000000..5129befd --- /dev/null +++ b/pydis_core/utils/cooldown.py @@ -0,0 +1,220 @@ +"""Helpers for setting a cooldown on commands.""" + +from __future__ import annotations + +import asyncio +import random +import time +import typing +import weakref +from collections.abc import Awaitable, Callable, Hashable, Iterable +from contextlib import suppress +from dataclasses import dataclass + +import discord +from discord.ext.commands import CommandError, Context + +from pydis_core.utils import scheduling +from pydis_core.utils.function import command_wraps + +__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"] + +_KEYWORD_SEP_SENTINEL = object() + +_ArgsList = list[object] +_HashableArgsTuple = tuple[Hashable, ...] + +if typing.TYPE_CHECKING: + import typing_extensions + from pydis_core import BotBase + +P = typing.ParamSpec("P") +"""The command's signature.""" +R = typing.TypeVar("R") +"""The command's return value.""" + + +class CommandOnCooldown(CommandError, typing.Generic[P, R]): + """Raised when a command is invoked while on cooldown.""" + + def __init__( + self, + message: str | None, + function: Callable[P, Awaitable[R]], + /, + *args: P.args, + **kwargs: P.kwargs, + ): + super().__init__(message, function, args, kwargs) + self._function = function + self._args = args + self._kwargs = kwargs + + async def call_without_cooldown(self) -> R: + """ + Run the command this cooldown blocked. + + Returns: + The command's return value. + """ + return await self._function(*self._args, **self._kwargs) + + +@dataclass +class _CooldownItem: + non_hashable_arguments: _ArgsList + timeout_timestamp: float + + +@dataclass +class _SeparatedArguments: + """Arguments separated into their hashable and non-hashable parts.""" + + hashable: _HashableArgsTuple + non_hashable: _ArgsList + + @classmethod + def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self: + """Create a new instance from full call arguments.""" + hashable = list[Hashable]() + non_hashable = list[object]() + + for item in call_arguments: + try: + hash(item) + except TypeError: + non_hashable.append(item) + else: + hashable.append(item) + + return cls(tuple(hashable), non_hashable) + + +class _CommandCooldownManager: + """ + Manage invocation cooldowns for a command through the arguments the command is called with. + + Use `set_cooldown` to set a cooldown, + and `is_on_cooldown` to check for a cooldown for a channel with the given arguments. + A cooldown lasts for `cooldown_duration` seconds. + """ + + def __init__(self, *, cooldown_duration: float): + self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]() + self._cooldown_duration = cooldown_duration + self.cleanup_task = scheduling.create_task( + self._periodical_cleanup(random.uniform(0, 10)), + name="CooldownManager cleanup", + ) + weakref.finalize(self, self.cleanup_task.cancel) + + def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None: + """Set `call_arguments` arguments on cooldown in `channel`.""" + timeout_timestamp = time.monotonic() + self._cooldown_duration + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.setdefault( + (channel, separated_arguments.hashable), + [], + ) + + for item in cooldowns_list: + if item.non_hashable_arguments == separated_arguments.non_hashable: + item.timeout_timestamp = timeout_timestamp + return + + cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp)) + + def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool: + """Check whether `call_arguments` is on cooldown in `channel`.""" + current_time = time.monotonic() + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.get( + (channel, separated_arguments.hashable), + [], + ) + + for item in cooldowns_list: + if item.non_hashable_arguments == separated_arguments.non_hashable: + return item.timeout_timestamp > current_time + return False + + async def _periodical_cleanup(self, initial_delay: float) -> None: + """ + Delete stale items every hour after waiting for `initial_delay`. + + The `initial_delay` ensures cleanups are not running for every command at the same time. + A strong reference to self is only kept while cleanup is running. + """ + weak_self = weakref.ref(self) + del self + + await asyncio.sleep(initial_delay) + while True: + await asyncio.sleep(60 * 60) + weak_self()._delete_stale_items() + + def _delete_stale_items(self) -> None: + """Remove expired items from internal collections.""" + current_time = time.monotonic() + + for key, cooldowns_list in self._cooldowns.copy().items(): + filtered_cooldowns = [ + cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time + ] + + if not filtered_cooldowns: + del self._cooldowns[key] + else: + self._cooldowns[key] = filtered_cooldowns + + +def _create_argument_tuple(*args: object, **kwargs: object) -> tuple[object, ...]: + return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items()) + + +def block_duplicate_invocations( + *, + cooldown_duration: float = 5, + send_notice: bool = False, + args_preprocessor: Callable[P, Iterable[object]] | None = None, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """ + Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds. + + Args: + cooldown_duration: Length of the cooldown in seconds. + send_notice: If :obj:`True`, notify the user about the cooldown with a reply. + args_preprocessor: If specified, this function is called with the args and kwargs the function is called with, + its return value is then used to check for the cooldown instead of the raw arguments. + + Returns: + A decorator that adds a wrapper which applies the cooldowns. + + Warning: + The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown. + """ + + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration) + + @command_wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if args_preprocessor is not None: + all_args = args_preprocessor(*args, **kwargs) + else: + all_args = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command + ctx = typing.cast("Context[BotBase]", args[1]) + + if not isinstance(ctx.channel, discord.DMChannel): + if mgr.is_on_cooldown(ctx.channel, all_args): + if send_notice: + with suppress(discord.NotFound): + await ctx.reply("The command is on cooldown with the given arguments.") + raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs) + mgr.set_cooldown(ctx.channel, all_args) + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py new file mode 100644 index 00000000..d89163ec --- /dev/null +++ b/pydis_core/utils/function.py @@ -0,0 +1,111 @@ +"""Utils for manipulating functions.""" + +from __future__ import annotations + +import functools +import types +import typing +from collections.abc import Callable, Sequence, Set + +__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] + + +if typing.TYPE_CHECKING: + _P = typing.ParamSpec("_P") + _R = typing.TypeVar("_R") + + +class GlobalNameConflictError(Exception): + """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" + + +def update_wrapper_globals( + wrapper: Callable[_P, _R], + wrapped: Callable[_P, _R], + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[_P, _R]: + r""" + Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals. + + For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function + to resolve their values. This breaks for decorators that replace the function because they have + their own globals. + + .. warning:: + This function captures the state of ``wrapped``\'s module's globals when it's called; + changes won't be reflected in the new function's globals. + + Args: + wrapper: The function to wrap. + wrapped: The function to wrap with. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Raises: + :exc:`GlobalNameConflictError`: + If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints, + and is not in ``ignored_conflict_names``. + """ + wrapped = typing.cast(types.FunctionType, wrapped) + wrapper = typing.cast(types.FunctionType, wrapper) + + annotation_global_names = ( + ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) + ) + # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. + shared_globals = ( + set(wrapper.__code__.co_names) + & set(annotation_global_names) + & set(wrapped.__globals__) + & set(wrapper.__globals__) + - ignored_conflict_names + ) + if shared_globals: + raise GlobalNameConflictError( + f"wrapper and the wrapped function share the following " + f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " + f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." + ) + + new_globals = wrapper.__globals__.copy() + new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) + return types.FunctionType( + code=wrapper.__code__, + globals=new_globals, + name=wrapper.__name__, + argdefs=wrapper.__defaults__, + closure=wrapper.__closure__, + ) + + +def command_wraps( + wrapped: Callable[_P, _R], + assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS, + updated: Sequence[str] = functools.WRAPPER_UPDATES, + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + r""" + Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation. + + See :func:`update_wrapper_globals` for more details on how the globals are updated. + + Args: + wrapped: The function to wrap with. + assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``. + updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Returns: + A decorator that behaves like :func:`functools.wraps`, + with the wrapper replaced with the function :func:`update_wrapper_globals` returned. + """ # noqa: D200 + def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]: + return functools.update_wrapper( + update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), + wrapped, + assigned, + updated, + ) + + return decorator diff --git a/pydis_core/utils/interactions.py b/pydis_core/utils/interactions.py new file mode 100644 index 00000000..3e4acffe --- /dev/null +++ b/pydis_core/utils/interactions.py @@ -0,0 +1,98 @@ +import contextlib +from typing import Optional, Sequence + +from discord import ButtonStyle, Interaction, Message, NotFound, ui + +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) + + +class ViewWithUserAndRoleCheck(ui.View): + """ + A view that allows the original invoker and moderators to interact with it. + + Args: + allowed_users: A sequence of user's ids who are allowed to interact with the view. + allowed_roles: A sequence of role ids that are allowed to interact with the view. + timeout: Timeout in seconds from last interaction with the UI before no longer accepting input. + If ``None`` then there is no timeout. + message: The message to remove the view from on timeout. This can also be set with + ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated. + """ + + def __init__( + self, + *, + allowed_users: Sequence[int], + allowed_roles: Sequence[int], + timeout: Optional[float] = 180.0, + message: Optional[Message] = None + ) -> None: + super().__init__(timeout=timeout) + self.allowed_users = allowed_users + self.allowed_roles = allowed_roles + self.message = message + + async def interaction_check(self, interaction: Interaction) -> bool: + """ + Ensure the user clicking the button is the view invoker, or a moderator. + + Args: + interaction: The interaction that occurred. + """ + if interaction.user.id in self.allowed_users: + log.trace( + "Allowed interaction by %s (%d) on %d as they are an allowed user.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])): + log.trace( + "Allowed interaction by %s (%d)on %d as they have an allowed role.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + await interaction.response.send_message("This is not your button to click!", ephemeral=True) + return False + + async def on_timeout(self) -> None: + """Remove the view from ``self.message`` if set.""" + if self.message: + with contextlib.suppress(NotFound): + # Cover the case where this message has already been deleted by external means + await self.message.edit(view=None) + + +class DeleteMessageButton(ui.Button): + """ + A button that can be added to a view to delete the message containing the view on click. + + This button itself carries out no interaction checks, these should be done by the parent view. + + See :obj:`pydis_core.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks. + + Args: + style (:literal-url:`ButtonStyle <https://discordpy.readthedocs.io/en/latest/interactions/api.html#discord.ButtonStyle>`): + The style of the button, set to ``ButtonStyle.secondary`` if not specified. + label: The label of the button, set to "Delete" if not specified. + """ # noqa: E501 + + def __init__( + self, + *, + style: ButtonStyle = ButtonStyle.secondary, + label: str = "Delete", + **kwargs + ): + super().__init__(style=style, label=label, **kwargs) + + async def callback(self, interaction: Interaction) -> None: + """Delete the original message on button click.""" + await interaction.message.delete() diff --git a/pydis_core/utils/logging.py b/pydis_core/utils/logging.py new file mode 100644 index 00000000..7814f348 --- /dev/null +++ b/pydis_core/utils/logging.py @@ -0,0 +1,51 @@ +"""Common logging related functions.""" + +import logging +import typing + +if typing.TYPE_CHECKING: + LoggerClass = logging.Logger +else: + LoggerClass = logging.getLoggerClass() + +TRACE_LEVEL = 5 + + +class CustomLogger(LoggerClass): + """Custom implementation of the :obj:`logging.Logger` class with an added :obj:`trace` method.""" + + def trace(self, msg: str, *args, **kwargs) -> None: + """ + Log the given message with the severity ``"TRACE"``. + + To pass exception information, use the keyword argument exc_info with a true value: + + .. code-block:: py + + logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) + + Args: + msg: The message to be logged. + args, kwargs: Passed to the base log function as is. + """ + if self.isEnabledFor(TRACE_LEVEL): + self.log(TRACE_LEVEL, msg, *args, **kwargs) + + +def get_logger(name: typing.Optional[str] = None) -> CustomLogger: + """ + Utility to make mypy recognise that logger is of type :obj:`CustomLogger`. + + Args: + name: The name given to the logger. + + Returns: + An instance of the :obj:`CustomLogger` class. + """ + return typing.cast(CustomLogger, logging.getLogger(name)) + + +# Setup trace level logging so that we can use it within pydis_core. +logging.TRACE = TRACE_LEVEL +logging.setLoggerClass(CustomLogger) +logging.addLevelName(TRACE_LEVEL, "TRACE") diff --git a/pydis_core/utils/members.py b/pydis_core/utils/members.py new file mode 100644 index 00000000..b6eacc88 --- /dev/null +++ b/pydis_core/utils/members.py @@ -0,0 +1,57 @@ +"""Useful helper functions for interactin with :obj:`discord.Member` objects.""" +import typing +from collections import abc + +import discord + +from pydis_core.utils import logging + +log = logging.get_logger(__name__) + + +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]: + """ + Attempt to get a member from cache; on failure fetch from the API. + + Returns: + The :obj:`discord.Member` or :obj:`None` to indicate the member could not be found. + """ + if member := guild.get_member(member_id): + log.trace(f"{member} retrieved from cache.") + else: + try: + member = await guild.fetch_member(member_id) + except discord.errors.NotFound: + log.trace(f"Failed to fetch {member_id} from API.") + return None + log.trace(f"{member} fetched from API.") + return member + + +async def handle_role_change( + member: discord.Member, + coro: typing.Callable[[discord.Role], abc.Coroutine], + role: discord.Role +) -> None: + """ + Await the given ``coro`` with ``role`` as the sole argument. + + Handle errors that we expect to be raised from + :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`. + + Args: + member: The member that is being modified for logging purposes. + coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`. + role: The role to be passed to ``coro``. + """ + try: + await coro(role) + except discord.NotFound: + log.error(f"Failed to change role for {member} ({member.id}): member not found") + except discord.Forbidden: + log.error( + f"Forbidden to change role for {member} ({member.id}); " + f"possibly due to role hierarchy" + ) + except discord.HTTPException as e: + log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/pydis_core/utils/regex.py b/pydis_core/utils/regex.py new file mode 100644 index 00000000..de82a1ed --- /dev/null +++ b/pydis_core/utils/regex.py @@ -0,0 +1,54 @@ +"""Common regular expressions.""" + +import re + +DISCORD_INVITE = re.compile( + r"(https?://)?(www\.)?" # Optional http(s) and www. + r"(discord([.,]|dot)gg|" # Could be discord.gg/ + r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/ + r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/ + r"discord([.,]|dot)me|" # or discord.me + r"discord([.,]|dot)li|" # or discord.li + r"discord([.,]|dot)io|" # or discord.io. + r"((?<!\w)([.,]|dot))gg" # or .gg/ + r")(/|slash)" # / or 'slash' + r"(?P<invite>\S+)", # the invite code itself + flags=re.IGNORECASE +) +""" +Regex for Discord server invites. + +.. warning:: + This regex pattern will capture until a whitespace, if you are to use the 'invite' capture group in + any HTTP requests or similar. Please ensure you sanitise the output using something + such as :func:`urllib.parse.quote`. + +:meta hide-value: +""" + +FORMATTED_CODE_REGEX = re.compile( + r"(?P<delim>(?P<block>```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P<lang>[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P<code>.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)", # match the exact same delimiter from the start again + flags=re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive +) +""" +Regex for formatted code, using Discord's code blocks. + +:meta hide-value: +""" + +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P<code>.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + flags=re.DOTALL # "." also matches newlines +) +""" +Regex for raw code, *not* using Discord's code blocks. + +:meta hide-value: +""" diff --git a/pydis_core/utils/scheduling.py b/pydis_core/utils/scheduling.py new file mode 100644 index 00000000..eced4a3d --- /dev/null +++ b/pydis_core/utils/scheduling.py @@ -0,0 +1,252 @@ +"""Generic python scheduler.""" + +import asyncio +import contextlib +import inspect +import typing +from collections import abc +from datetime import datetime +from functools import partial + +from pydis_core.utils import logging + + +class Scheduler: + """ + Schedule the execution of coroutines and keep track of them. + + When instantiating a :obj:`Scheduler`, a name must be provided. This name is used to distinguish the + instance's log messages from other instances. Using the name of the class or module containing + the instance is suggested. + + Coroutines can be scheduled immediately with :obj:`schedule` or in the future with :obj:`schedule_at` + or :obj:`schedule_later`. A unique ID is required to be given in order to keep track of the + resulting Tasks. Any scheduled task can be cancelled prematurely using :obj:`cancel` by providing + the same ID used to schedule it. + + The ``in`` operator is supported for checking if a task with a given ID is currently scheduled. + + Any exception raised in a scheduled task is logged when the task is done. + """ + + def __init__(self, name: str): + """ + Initialize a new :obj:`Scheduler` instance. + + Args: + name: The name of the :obj:`Scheduler`. Used in logging, and namespacing. + """ + self.name = name + + self._log = logging.get_logger(f"{__name__}.{name}") + self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {} + + def __contains__(self, task_id: abc.Hashable) -> bool: + """ + Return :obj:`True` if a task with the given ``task_id`` is currently scheduled. + + Args: + task_id: The task to look for. + + Returns: + :obj:`True` if the task was found. + """ + return task_id in self._scheduled_tasks + + def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: + """ + Schedule the execution of a ``coroutine``. + + If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + self._log.trace(f"Scheduling task #{task_id}...") + + msg = f"Cannot schedule an already started coroutine for #{task_id}" + assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg + + if task_id in self._scheduled_tasks: + self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") + coroutine.close() + return + + task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") + task.add_done_callback(partial(self._task_done_callback, task_id)) + + self._scheduled_tasks[task_id] = task + self._log.debug(f"Scheduled task #{task_id} {id(task)}.") + + def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: + """ + Schedule ``coroutine`` to be executed at the given ``time``. + + If ``time`` is timezone aware, then use that timezone to calculate now() when subtracting. + If ``time`` is naïve, then use UTC. + + If ``time`` is in the past, schedule ``coroutine`` immediately. + + If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + time: The time to start the task. + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() + delay = (time - now_datetime).total_seconds() + if delay > 0: + coroutine = self._await_later(delay, task_id, coroutine) + + self.schedule(task_id, coroutine) + + def schedule_later( + self, + delay: typing.Union[int, float], + task_id: abc.Hashable, + coroutine: abc.Coroutine + ) -> None: + """ + Schedule ``coroutine`` to be executed after ``delay`` seconds. + + If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + delay: How long to wait before starting the task. + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + self.schedule(task_id, self._await_later(delay, task_id, coroutine)) + + def cancel(self, task_id: abc.Hashable) -> None: + """ + Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist. + + Args: + task_id: The task's unique ID. + """ + self._log.trace(f"Cancelling task #{task_id}...") + + try: + task = self._scheduled_tasks.pop(task_id) + except KeyError: + self._log.warning(f"Failed to unschedule {task_id} (no task found).") + else: + task.cancel() + + self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") + + def cancel_all(self) -> None: + """Unschedule all known tasks.""" + self._log.debug("Unscheduling all tasks") + + for task_id in self._scheduled_tasks.copy(): + self.cancel(task_id) + + async def _await_later( + self, + delay: typing.Union[int, float], + task_id: abc.Hashable, + coroutine: abc.Coroutine + ) -> None: + """Await ``coroutine`` after ``delay`` seconds.""" + try: + self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") + await asyncio.sleep(delay) + + # Use asyncio.shield to prevent the coroutine from cancelling itself. + self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") + await asyncio.shield(coroutine) + finally: + # Close it to prevent unawaited coroutine warnings, + # which would happen if the task was cancelled during the sleep. + # Only close it if it's not been awaited yet. This check is important because the + # coroutine may cancel this task, which would also trigger the finally block. + state = inspect.getcoroutinestate(coroutine) + if state == "CORO_CREATED": + self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") + coroutine.close() + else: + self._log.debug(f"Finally block reached for #{task_id}; {state=}") + + def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None: + """ + Delete the task and raise its exception if one exists. + + If ``done_task`` and the task associated with ``task_id`` are different, then the latter + will not be deleted. In this case, a new task was likely rescheduled with the same ID. + """ + self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") + + scheduled_task = self._scheduled_tasks.get(task_id) + + if scheduled_task and done_task is scheduled_task: + # A task for the ID exists and is the same as the done task. + # Since this is the done callback, the task is already done so no need to cancel it. + self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") + del self._scheduled_tasks[task_id] + elif scheduled_task: + # A new task was likely rescheduled with the same ID. + self._log.debug( + f"The scheduled task #{task_id} {id(scheduled_task)} " + f"and the done task {id(done_task)} differ." + ) + elif not done_task.cancelled(): + self._log.warning( + f"Task #{task_id} not found while handling task {id(done_task)}! " + f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." + ) + + with contextlib.suppress(asyncio.CancelledError): + exception = done_task.exception() + # Log the exception if one exists. + if exception: + self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) + + +TASK_RETURN = typing.TypeVar("TASK_RETURN") + + +def create_task( + coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN], + *, + suppressed_exceptions: tuple[type[Exception], ...] = (), + event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, + **kwargs, +) -> asyncio.Task[TASK_RETURN]: + """ + Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task. + + If the ``event_loop`` kwarg is provided, the task is created from that event loop, + otherwise the running loop is used. + + Args: + coro: The function to call. + suppressed_exceptions: Exceptions to be handled by the task. + event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from. + kwargs: Passed to :py:func:`asyncio.create_task`. + + Returns: + asyncio.Task: The wrapped task. + """ + if event_loop is not None: + task = event_loop.create_task(coro, **kwargs) + else: + task = asyncio.create_task(coro, **kwargs) + task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) + return task + + +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None: + """Retrieve and log the exception raised in ``task`` if one exists.""" + with contextlib.suppress(asyncio.CancelledError): + exception = task.exception() + # Log the exception if one exists. + if exception and not isinstance(exception, suppressed_exceptions): + log = logging.get_logger(__name__) + log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) |