summaryrefslogtreecommitdiffstats
path: root/pydis_core
diff options
context:
space:
mode:
Diffstat (limited to 'pydis_core')
-rw-r--r--pydis_core/__init__.py15
-rw-r--r--pydis_core/_bot.py288
-rw-r--r--pydis_core/async_stats.py57
-rw-r--r--pydis_core/exts/__init__.py4
-rw-r--r--pydis_core/site_api.py157
-rw-r--r--pydis_core/utils/__init__.py50
-rw-r--r--pydis_core/utils/_extensions.py57
-rw-r--r--pydis_core/utils/_monkey_patches.py73
-rw-r--r--pydis_core/utils/caching.py65
-rw-r--r--pydis_core/utils/channel.py54
-rw-r--r--pydis_core/utils/commands.py38
-rw-r--r--pydis_core/utils/cooldown.py220
-rw-r--r--pydis_core/utils/function.py111
-rw-r--r--pydis_core/utils/interactions.py98
-rw-r--r--pydis_core/utils/logging.py51
-rw-r--r--pydis_core/utils/members.py57
-rw-r--r--pydis_core/utils/regex.py54
-rw-r--r--pydis_core/utils/scheduling.py252
18 files changed, 1701 insertions, 0 deletions
diff --git a/pydis_core/__init__.py b/pydis_core/__init__.py
new file mode 100644
index 00000000..a09feeaa
--- /dev/null
+++ b/pydis_core/__init__.py
@@ -0,0 +1,15 @@
+"""Useful utilities and tools for Discord bot development."""
+
+from pydis_core import async_stats, exts, site_api, utils
+from pydis_core._bot import BotBase, StartupError
+
+__all__ = [
+ async_stats,
+ BotBase,
+ exts,
+ utils,
+ site_api,
+ StartupError,
+]
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/pydis_core/_bot.py b/pydis_core/_bot.py
new file mode 100644
index 00000000..56814f27
--- /dev/null
+++ b/pydis_core/_bot.py
@@ -0,0 +1,288 @@
+import asyncio
+import socket
+import types
+import warnings
+from contextlib import suppress
+from typing import Optional
+
+import aiohttp
+import discord
+from discord.ext import commands
+
+from pydis_core.async_stats import AsyncStatsClient
+from pydis_core.site_api import APIClient
+from pydis_core.utils import scheduling
+from pydis_core.utils._extensions import walk_extensions
+from pydis_core.utils.logging import get_logger
+
+try:
+ from async_rediscache import RedisSession
+ REDIS_AVAILABLE = True
+except ImportError:
+ RedisSession = None
+ REDIS_AVAILABLE = False
+
+log = get_logger()
+
+
+class StartupError(Exception):
+ """Exception class for startup errors."""
+
+ def __init__(self, base: Exception):
+ super().__init__()
+ self.exception = base
+
+
+class BotBase(commands.Bot):
+ """A sub-class that implements many common features that Python Discord bots use."""
+
+ def __init__(
+ self,
+ *args,
+ guild_id: int,
+ allowed_roles: list,
+ http_session: aiohttp.ClientSession,
+ redis_session: Optional[RedisSession] = None,
+ api_client: Optional[APIClient] = None,
+ statsd_url: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Initialise the base bot instance.
+
+ Args:
+ guild_id: The ID of the guild used for :func:`wait_until_guild_available`.
+ allowed_roles: A list of role IDs that the bot is allowed to mention.
+ http_session (aiohttp.ClientSession): The session to use for the bot.
+ redis_session: The `async_rediscache.RedisSession`_ to use for the bot.
+ api_client: The :obj:`pydis_core.site_api.APIClient` instance to use for the bot.
+ statsd_url: The URL of the statsd server to use for the bot. If not given,
+ a dummy statsd client will be created.
+
+ .. _async_rediscache.RedisSession: https://github.com/SebastiaanZ/async-rediscache#creating-a-redissession
+ """
+ super().__init__(
+ *args,
+ allowed_roles=allowed_roles,
+ **kwargs,
+ )
+
+ self.guild_id = guild_id
+ self.http_session = http_session
+ self.api_client = api_client
+ self.statsd_url = statsd_url
+
+ if redis_session and not REDIS_AVAILABLE:
+ warnings.warn("redis_session kwarg passed, but async-rediscache not installed!")
+ elif redis_session:
+ self.redis_session = redis_session
+
+ self._resolver: Optional[aiohttp.AsyncResolver] = None
+ self._connector: Optional[aiohttp.TCPConnector] = None
+
+ self._statsd_timerhandle: Optional[asyncio.TimerHandle] = None
+ self._guild_available: Optional[asyncio.Event] = None
+
+ self.stats: Optional[AsyncStatsClient] = None
+
+ self.all_extensions: Optional[frozenset[str]] = None
+
+ def _connect_statsd(
+ self,
+ statsd_url: str,
+ loop: asyncio.AbstractEventLoop,
+ retry_after: int = 2,
+ attempt: int = 1
+ ) -> None:
+ """Callback used to retry a connection to statsd if it should fail."""
+ if attempt >= 8:
+ log.error(
+ "Reached 8 attempts trying to reconnect AsyncStatsClient to %s. "
+ "Aborting and leaving the dummy statsd client in place.",
+ statsd_url,
+ )
+ return
+
+ try:
+ self.stats = AsyncStatsClient(loop, statsd_url, 8125, prefix="bot")
+ except socket.gaierror:
+ log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})")
+ # Use a fallback strategy for retrying, up to 8 times.
+ self._statsd_timerhandle = loop.call_later(
+ retry_after,
+ self._connect_statsd,
+ statsd_url,
+ retry_after * 2,
+ attempt + 1
+ )
+
+ async def load_extensions(self, module: types.ModuleType) -> None:
+ """
+ Load all the extensions within the given module and save them to ``self.all_extensions``.
+
+ This should be ran in a task on the event loop to avoid deadlocks caused by ``wait_for`` calls.
+ """
+ await self.wait_until_guild_available()
+ self.all_extensions = walk_extensions(module)
+
+ for extension in self.all_extensions:
+ scheduling.create_task(self.load_extension(extension))
+
+ def _add_root_aliases(self, command: commands.Command) -> None:
+ """Recursively add root aliases for ``command`` and any of its subcommands."""
+ if isinstance(command, commands.Group):
+ for subcommand in command.commands:
+ self._add_root_aliases(subcommand)
+
+ for alias in getattr(command, "root_aliases", ()):
+ if alias in self.all_commands:
+ raise commands.CommandRegistrationError(alias, alias_conflict=True)
+
+ self.all_commands[alias] = command
+
+ def _remove_root_aliases(self, command: commands.Command) -> None:
+ """Recursively remove root aliases for ``command`` and any of its subcommands."""
+ if isinstance(command, commands.Group):
+ for subcommand in command.commands:
+ self._remove_root_aliases(subcommand)
+
+ for alias in getattr(command, "root_aliases", ()):
+ self.all_commands.pop(alias, None)
+
+ async def add_cog(self, cog: commands.Cog) -> None:
+ """Add the given ``cog`` to the bot and log the operation."""
+ await super().add_cog(cog)
+ log.info(f"Cog loaded: {cog.qualified_name}")
+
+ def add_command(self, command: commands.Command) -> None:
+ """Add ``command`` as normal and then add its root aliases to the bot."""
+ super().add_command(command)
+ self._add_root_aliases(command)
+
+ def remove_command(self, name: str) -> Optional[commands.Command]:
+ """
+ Remove a command/alias as normal and then remove its root aliases from the bot.
+
+ Individual root aliases cannot be removed by this function.
+ To remove them, either remove the entire command or manually edit `bot.all_commands`.
+ """
+ command = super().remove_command(name)
+ if command is None:
+ # Even if it's a root alias, there's no way to get the Bot instance to remove the alias.
+ return None
+
+ self._remove_root_aliases(command)
+ return command
+
+ def clear(self) -> None:
+ """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one."""
+ raise NotImplementedError("Re-using a Bot object after closing it is not supported.")
+
+ async def on_guild_unavailable(self, guild: discord.Guild) -> None:
+ """Clear the internal guild available event when self.guild_id becomes unavailable."""
+ if guild.id != self.guild_id:
+ return
+
+ self._guild_available.clear()
+
+ async def on_guild_available(self, guild: discord.Guild) -> None:
+ """
+ Set the internal guild available event when self.guild_id becomes available.
+
+ If the cache appears to still be empty (no members, no channels, or no roles), the event
+ will not be set and `guild_available_but_cache_empty` event will be emitted.
+ """
+ if guild.id != self.guild_id:
+ return
+
+ if not guild.roles or not guild.members or not guild.channels:
+ msg = "Guild available event was dispatched but the cache appears to still be empty!"
+ await self.log_to_dev_log(msg)
+ return
+
+ self._guild_available.set()
+
+ async def log_to_dev_log(self, message: str) -> None:
+ """Log the given message to #dev-log."""
+ ...
+
+ async def wait_until_guild_available(self) -> None:
+ """
+ Wait until the guild that matches the ``guild_id`` given at init is available (and the cache is ready).
+
+ The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE
+ gateway event before giving up and thus not populating the cache for unavailable guilds.
+ """
+ await self._guild_available.wait()
+
+ async def setup_hook(self) -> None:
+ """
+ An async init to startup generic services.
+
+ Connects to statsd, and calls
+ :func:`AsyncStatsClient.create_socket <pydis_core.async_stats.AsyncStatsClient.create_socket>`
+ and :func:`ping_services`.
+ """
+ loop = asyncio.get_running_loop()
+
+ self._guild_available = asyncio.Event()
+
+ self._resolver = aiohttp.AsyncResolver()
+ self._connector = aiohttp.TCPConnector(
+ resolver=self._resolver,
+ family=socket.AF_INET,
+ )
+ self.http.connector = self._connector
+
+ if getattr(self, "redis_session", False) and not self.redis_session.valid:
+ # If the RedisSession was somehow closed, we try to reconnect it
+ # here. Normally, this shouldn't happen.
+ await self.redis_session.connect(ping=True)
+
+ # Create dummy stats client first, in case `statsd_url` is unreachable or None
+ self.stats = AsyncStatsClient(loop, "127.0.0.1")
+ if self.statsd_url:
+ self._connect_statsd(self.statsd_url, loop)
+
+ await self.stats.create_socket()
+
+ try:
+ await self.ping_services()
+ except Exception as e:
+ raise StartupError(e)
+
+ async def ping_services(self) -> None:
+ """Ping all required services on setup to ensure they are up before starting."""
+ ...
+
+ async def close(self) -> None:
+ """Close the Discord connection, and the aiohttp session, connector, statsd client, and resolver."""
+ # Done before super().close() to allow tasks finish before the HTTP session closes.
+ for ext in list(self.extensions):
+ with suppress(Exception):
+ await self.unload_extension(ext)
+
+ for cog in list(self.cogs):
+ with suppress(Exception):
+ await self.remove_cog(cog)
+
+ # Now actually do full close of bot
+ await super().close()
+
+ if self.api_client:
+ await self.api_client.close()
+
+ if self.http_session:
+ await self.http_session.close()
+
+ if self._connector:
+ await self._connector.close()
+
+ if self._resolver:
+ await self._resolver.close()
+
+ if getattr(self.stats, "_transport", False):
+ self.stats._transport.close()
+
+ if self._statsd_timerhandle:
+ self._statsd_timerhandle.cancel()
diff --git a/pydis_core/async_stats.py b/pydis_core/async_stats.py
new file mode 100644
index 00000000..411325e3
--- /dev/null
+++ b/pydis_core/async_stats.py
@@ -0,0 +1,57 @@
+"""An async transport method for statsd communication."""
+
+import asyncio
+import socket
+from typing import Optional
+
+from statsd.client.base import StatsClientBase
+
+from pydis_core.utils import scheduling
+
+
+class AsyncStatsClient(StatsClientBase):
+ """An async implementation of :obj:`statsd.client.base.StatsClientBase` that supports async stat communication."""
+
+ def __init__(
+ self,
+ loop: asyncio.AbstractEventLoop,
+ host: str = 'localhost',
+ port: int = 8125,
+ prefix: str = None
+ ):
+ """
+ Create a new :obj:`AsyncStatsClient`.
+
+ Args:
+ loop (asyncio.AbstractEventLoop): The event loop to use when creating the
+ :obj:`asyncio.loop.create_datagram_endpoint`.
+ host: The host to connect to.
+ port: The port to connect to.
+ prefix: The prefix to use for all stats.
+ """
+ _, _, _, _, addr = socket.getaddrinfo(
+ host, port, socket.AF_INET, socket.SOCK_DGRAM
+ )[0]
+ self._addr = addr
+ self._prefix = prefix
+ self._loop = loop
+ self._transport: Optional[asyncio.DatagramTransport] = None
+
+ async def create_socket(self) -> None:
+ """Use :obj:`asyncio.loop.create_datagram_endpoint` from the loop given on init to create a socket."""
+ self._transport, _ = await self._loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol,
+ family=socket.AF_INET,
+ remote_addr=self._addr
+ )
+
+ def _send(self, data: str) -> None:
+ """Start an async task to send data to statsd."""
+ scheduling.create_task(self._async_send(data), event_loop=self._loop)
+
+ async def _async_send(self, data: str) -> None:
+ """Send data to the statsd server using the async transport."""
+ self._transport.sendto(data.encode('ascii'), self._addr)
+
+
+__all__ = ['AsyncStatsClient']
diff --git a/pydis_core/exts/__init__.py b/pydis_core/exts/__init__.py
new file mode 100644
index 00000000..afd56166
--- /dev/null
+++ b/pydis_core/exts/__init__.py
@@ -0,0 +1,4 @@
+"""Reusable Discord cogs."""
+__all__ = []
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/pydis_core/site_api.py b/pydis_core/site_api.py
new file mode 100644
index 00000000..c17d2642
--- /dev/null
+++ b/pydis_core/site_api.py
@@ -0,0 +1,157 @@
+"""An API wrapper around the Site API."""
+
+import asyncio
+from typing import Optional
+from urllib.parse import quote as quote_url
+
+import aiohttp
+
+from pydis_core.utils.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class ResponseCodeError(ValueError):
+ """Raised in :meth:`APIClient.request` when a non-OK HTTP response is received."""
+
+ def __init__(
+ self,
+ response: aiohttp.ClientResponse,
+ response_json: Optional[dict] = None,
+ response_text: Optional[str] = None
+ ):
+ """
+ Initialize a new :obj:`ResponseCodeError` instance.
+
+ Args:
+ response (:obj:`aiohttp.ClientResponse`): The response object from the request.
+ response_json: The JSON response returned from the request, if any.
+ response_text: The text of the request, if any.
+ """
+ self.status = response.status
+ self.response_json = response_json or {}
+ self.response_text = response_text
+ self.response = response
+
+ def __str__(self):
+ """Return a string representation of the error."""
+ response = self.response_json or self.response_text
+ return f"Status: {self.status} Response: {response}"
+
+
+class APIClient:
+ """A wrapper for the Django Site API."""
+
+ session: Optional[aiohttp.ClientSession] = None
+ loop: asyncio.AbstractEventLoop = None
+
+ def __init__(self, site_api_url: str, site_api_token: str, **session_kwargs):
+ """
+ Initialize a new :obj:`APIClient` instance.
+
+ Args:
+ site_api_url: The URL of the site API.
+ site_api_token: The token to use for authentication.
+ session_kwargs: Keyword arguments to pass to the :obj:`aiohttp.ClientSession` constructor.
+ """
+ self.site_api_url = site_api_url
+
+ auth_headers = {
+ 'Authorization': f"Token {site_api_token}"
+ }
+
+ if 'headers' in session_kwargs:
+ session_kwargs['headers'].update(auth_headers)
+ else:
+ session_kwargs['headers'] = auth_headers
+
+ # aiohttp will complain if APIClient gets instantiated outside a coroutine. Thankfully, we
+ # don't and shouldn't need to do that, so we can avoid scheduling a task to create it.
+ self.session = aiohttp.ClientSession(**session_kwargs)
+
+ def _url_for(self, endpoint: str) -> str:
+ return f"{self.site_api_url}/{quote_url(endpoint)}"
+
+ async def close(self) -> None:
+ """Close the aiohttp session."""
+ await self.session.close()
+
+ @staticmethod
+ async def maybe_raise_for_status(response: aiohttp.ClientResponse, should_raise: bool) -> None:
+ """
+ Raise :exc:`ResponseCodeError` for non-OK response if an exception should be raised.
+
+ Args:
+ response (:obj:`aiohttp.ClientResponse`): The response to check.
+ should_raise: Whether or not to raise an exception.
+
+ Raises:
+ :exc:`ResponseCodeError`:
+ If the response is not OK and ``should_raise`` is True.
+ """
+ if should_raise and response.status >= 400:
+ try:
+ response_json = await response.json()
+ raise ResponseCodeError(response=response, response_json=response_json)
+ except aiohttp.ContentTypeError:
+ response_text = await response.text()
+ raise ResponseCodeError(response=response, response_text=response_text)
+
+ async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """
+ Send an HTTP request to the site API and return the JSON response.
+
+ Args:
+ method: The HTTP method to use.
+ endpoint: The endpoint to send the request to.
+ raise_for_status: Whether or not to raise an exception if the response is not OK.
+ **kwargs: Any extra keyword arguments to pass to :func:`aiohttp.request`.
+
+ Returns:
+ The JSON response the API returns.
+
+ Raises:
+ :exc:`ResponseCodeError`:
+ If the response is not OK and ``raise_for_status`` is True.
+ """
+ async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp:
+ await self.maybe_raise_for_status(resp, raise_for_status)
+ return await resp.json()
+
+ async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Equivalent to :meth:`APIClient.request` with GET passed as the method."""
+ return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs)
+
+ async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Equivalent to :meth:`APIClient.request` with PATCH passed as the method."""
+ return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs)
+
+ async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Equivalent to :meth:`APIClient.request` with POST passed as the method."""
+ return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs)
+
+ async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Equivalent to :meth:`APIClient.request` with PUT passed as the method."""
+ return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs)
+
+ async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]:
+ """
+ Send a DELETE request to the site API and return the JSON response.
+
+ Args:
+ endpoint: The endpoint to send the request to.
+ raise_for_status: Whether or not to raise an exception if the response is not OK.
+ **kwargs: Any extra keyword arguments to pass to :func:`aiohttp.request`.
+
+ Returns:
+ The JSON response the API returns, or None if the response is 204 No Content.
+ """
+ async with self.session.delete(self._url_for(endpoint), **kwargs) as resp:
+ if resp.status == 204:
+ return None
+
+ await self.maybe_raise_for_status(resp, raise_for_status)
+ return await resp.json()
+
+
+__all__ = ['APIClient', 'ResponseCodeError']
diff --git a/pydis_core/utils/__init__.py b/pydis_core/utils/__init__.py
new file mode 100644
index 00000000..0542231e
--- /dev/null
+++ b/pydis_core/utils/__init__.py
@@ -0,0 +1,50 @@
+"""Useful utilities and tools for Discord bot development."""
+
+from pydis_core.utils import (
+ _monkey_patches,
+ caching,
+ channel,
+ commands,
+ cooldown,
+ function,
+ interactions,
+ logging,
+ members,
+ regex,
+ scheduling,
+)
+from pydis_core.utils._extensions import unqualify
+
+
+def apply_monkey_patches() -> None:
+ """
+ Applies all common monkey patches for our bots.
+
+ Patches :obj:`discord.ext.commands.Command` and :obj:`discord.ext.commands.Group` to support root aliases.
+ A ``root_aliases`` keyword argument is added to these two objects, which is a sequence of alias names
+ that will act as top-level groups rather than being aliases of the command's group.
+
+ It's stored as an attribute also named ``root_aliases``
+
+ Patches discord's internal ``send_typing`` method so that it ignores 403 errors from Discord.
+ When under heavy load Discord has added a CloudFlare worker to this route, which causes 403 errors to be thrown.
+ """
+ _monkey_patches._apply_monkey_patches()
+
+
+__all__ = [
+ apply_monkey_patches,
+ caching,
+ channel,
+ commands,
+ cooldown,
+ function,
+ interactions,
+ logging,
+ members,
+ regex,
+ scheduling,
+ unqualify,
+]
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/pydis_core/utils/_extensions.py b/pydis_core/utils/_extensions.py
new file mode 100644
index 00000000..536a0715
--- /dev/null
+++ b/pydis_core/utils/_extensions.py
@@ -0,0 +1,57 @@
+"""Utilities for loading Discord extensions."""
+
+import importlib
+import inspect
+import pkgutil
+import types
+from typing import NoReturn
+
+
+def unqualify(name: str) -> str:
+ """
+ Return an unqualified name given a qualified module/package ``name``.
+
+ Args:
+ name: The module name to unqualify.
+
+ Returns:
+ The unqualified module name.
+ """
+ return name.rsplit(".", maxsplit=1)[-1]
+
+
+def ignore_module(module: pkgutil.ModuleInfo) -> bool:
+ """Return whether the module with name `name` should be ignored."""
+ return any(name.startswith("_") for name in module.name.split("."))
+
+
+def walk_extensions(module: types.ModuleType) -> frozenset[str]:
+ """
+ Return all extension names from the given module.
+
+ Args:
+ module (types.ModuleType): The module to look for extensions in.
+
+ Returns:
+ A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`.
+ """
+
+ def on_error(name: str) -> NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ modules = set()
+
+ for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error):
+ if ignore_module(module_info):
+ # Ignore modules/packages that have a name starting with an underscore anywhere in their trees.
+ continue
+
+ if module_info.ispkg:
+ imported = importlib.import_module(module_info.name)
+ if not inspect.isfunction(getattr(imported, "setup", None)):
+ # If it lacks a setup function, it's not an extension.
+ continue
+
+ modules.add(module_info.name)
+
+ return frozenset(modules)
diff --git a/pydis_core/utils/_monkey_patches.py b/pydis_core/utils/_monkey_patches.py
new file mode 100644
index 00000000..f0a8dc9c
--- /dev/null
+++ b/pydis_core/utils/_monkey_patches.py
@@ -0,0 +1,73 @@
+"""Contains all common monkey patches, used to alter discord to fit our needs."""
+
+import logging
+import typing
+from datetime import datetime, timedelta
+from functools import partial, partialmethod
+
+from discord import Forbidden, http
+from discord.ext import commands
+
+log = logging.getLogger(__name__)
+
+
+class _Command(commands.Command):
+ """
+ A :obj:`discord.ext.commands.Command` subclass which supports root aliases.
+
+ A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as
+ top-level commands rather than being aliases of the command's group. It's stored as an attribute
+ also named ``root_aliases``.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.root_aliases = kwargs.get("root_aliases", [])
+
+ if not isinstance(self.root_aliases, (list, tuple)):
+ raise TypeError("Root aliases of a command must be a list or a tuple of strings.")
+
+
+class _Group(commands.Group, _Command):
+ """
+ A :obj:`discord.ext.commands.Group` subclass which supports root aliases.
+
+ A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as
+ top-level groups rather than being aliases of the command's group. It's stored as an attribute
+ also named ``root_aliases``.
+ """
+
+
+def _patch_typing() -> None:
+ """
+ Sometimes Discord turns off typing events by throwing 403s.
+
+ Handle those issues by patching discord's internal ``send_typing`` method so it ignores 403s in general.
+ """
+ log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. Stay safe!")
+
+ original = http.HTTPClient.send_typing
+ last_403: typing.Optional[datetime] = None
+
+ async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None:
+ nonlocal last_403
+ if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5):
+ log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.")
+ return
+ try:
+ await original(self, channel_id)
+ except Forbidden:
+ last_403 = datetime.utcnow()
+ log.warning("Got a 403 from typing event!")
+
+ http.HTTPClient.send_typing = honeybadger_type
+
+
+def _apply_monkey_patches() -> None:
+ """This is surfaced directly in pydis_core.utils.apply_monkey_patches()."""
+ commands.command = partial(commands.command, cls=_Command)
+ commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=_Command)
+
+ commands.group = partial(commands.group, cls=_Group)
+ commands.GroupMixin.group = partialmethod(commands.GroupMixin.group, cls=_Group)
+ _patch_typing()
diff --git a/pydis_core/utils/caching.py b/pydis_core/utils/caching.py
new file mode 100644
index 00000000..ac34bb9b
--- /dev/null
+++ b/pydis_core/utils/caching.py
@@ -0,0 +1,65 @@
+"""Utilities related to custom caches."""
+
+import functools
+import typing
+from collections import OrderedDict
+
+
+class AsyncCache:
+ """
+ LRU cache implementation for coroutines.
+
+ Once the cache exceeds the maximum size, keys are deleted in FIFO order.
+
+ An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key.
+ """
+
+ def __init__(self, max_size: int = 128):
+ """
+ Initialise a new :obj:`AsyncCache` instance.
+
+ Args:
+ max_size: How many items to store in the cache.
+ """
+ self._cache = OrderedDict()
+ self._max_size = max_size
+
+ def __call__(self, arg_offset: int = 0) -> typing.Callable:
+ """
+ Decorator for async cache.
+
+ Args:
+ arg_offset: The offset for the position of the key argument.
+
+ Returns:
+ A decorator to wrap the target function.
+ """
+
+ def decorator(function: typing.Callable) -> typing.Callable:
+ """
+ Define the async cache decorator.
+
+ Args:
+ function: The function to wrap.
+
+ Returns:
+ The wrapped function.
+ """
+
+ @functools.wraps(function)
+ async def wrapper(*args) -> typing.Any:
+ """Decorator wrapper for the caching logic."""
+ key = args[arg_offset:]
+
+ if key not in self._cache:
+ if len(self._cache) > self._max_size:
+ self._cache.popitem(last=False)
+
+ self._cache[key] = await function(*args)
+ return self._cache[key]
+ return wrapper
+ return decorator
+
+ def clear(self) -> None:
+ """Clear cache instance."""
+ self._cache.clear()
diff --git a/pydis_core/utils/channel.py b/pydis_core/utils/channel.py
new file mode 100644
index 00000000..854c64fd
--- /dev/null
+++ b/pydis_core/utils/channel.py
@@ -0,0 +1,54 @@
+"""Useful helper functions for interacting with various discord channel objects."""
+
+import discord
+from discord.ext.commands import Bot
+
+from pydis_core.utils import logging
+
+log = logging.get_logger(__name__)
+
+
+def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
+ """
+ Return whether the given ``channel`` in the the category with the id ``category_id``.
+
+ Args:
+ channel: The channel to check.
+ category_id: The category to check for.
+
+ Returns:
+ A bool depending on whether the channel is in the category.
+ """
+ return getattr(channel, "category_id", None) == category_id
+
+
+async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel:
+ """
+ Attempt to get or fetch the given ``channel_id`` from the bots cache, and return it.
+
+ Args:
+ bot: The :obj:`discord.ext.commands.Bot` instance to use for getting/fetching.
+ channel_id: The channel to get/fetch.
+
+ Raises:
+ :exc:`discord.InvalidData`
+ An unknown channel type was received from Discord.
+ :exc:`discord.HTTPException`
+ Retrieving the channel failed.
+ :exc:`discord.NotFound`
+ Invalid Channel ID.
+ :exc:`discord.Forbidden`
+ You do not have permission to fetch this channel.
+
+ Returns:
+ The channel from the ID.
+ """
+ log.trace(f"Getting the channel {channel_id}.")
+
+ channel = bot.get_channel(channel_id)
+ if not channel:
+ log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
+ channel = await bot.fetch_channel(channel_id)
+
+ log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
+ return channel
diff --git a/pydis_core/utils/commands.py b/pydis_core/utils/commands.py
new file mode 100644
index 00000000..7afd8137
--- /dev/null
+++ b/pydis_core/utils/commands.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from discord import Message
+from discord.ext.commands import BadArgument, Context, clean_content
+
+
+async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str:
+ """
+ Cleans a text argument or replied message's content.
+
+ Args:
+ ctx: The command's context
+ text: The provided text argument of the command (if given)
+
+ Raises:
+ :exc:`discord.ext.commands.BadArgument`
+ `text` wasn't provided and there's no reply message / reply message content.
+
+ Returns:
+ The cleaned version of `text`, if given, else replied message.
+ """
+ clean_content_converter = clean_content(fix_channel_mentions=True)
+
+ if text:
+ return await clean_content_converter.convert(ctx, text)
+
+ if (
+ (replied_message := getattr(ctx.message.reference, "resolved", None)) # message has a cached reference
+ and isinstance(replied_message, Message) # referenced message hasn't been deleted
+ ):
+ if not (content := ctx.message.reference.resolved.content):
+ # The referenced message doesn't have a content (e.g. embed/image), so raise error
+ raise BadArgument("The referenced message doesn't have a text content.")
+
+ return await clean_content_converter.convert(ctx, content)
+
+ # No text provided, and either no message was referenced or we can't access the content
+ raise BadArgument("Couldn't find text to clean. Provide a string or reply to a message to use its content.")
diff --git a/pydis_core/utils/cooldown.py b/pydis_core/utils/cooldown.py
new file mode 100644
index 00000000..5129befd
--- /dev/null
+++ b/pydis_core/utils/cooldown.py
@@ -0,0 +1,220 @@
+"""Helpers for setting a cooldown on commands."""
+
+from __future__ import annotations
+
+import asyncio
+import random
+import time
+import typing
+import weakref
+from collections.abc import Awaitable, Callable, Hashable, Iterable
+from contextlib import suppress
+from dataclasses import dataclass
+
+import discord
+from discord.ext.commands import CommandError, Context
+
+from pydis_core.utils import scheduling
+from pydis_core.utils.function import command_wraps
+
+__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"]
+
+_KEYWORD_SEP_SENTINEL = object()
+
+_ArgsList = list[object]
+_HashableArgsTuple = tuple[Hashable, ...]
+
+if typing.TYPE_CHECKING:
+ import typing_extensions
+ from pydis_core import BotBase
+
+P = typing.ParamSpec("P")
+"""The command's signature."""
+R = typing.TypeVar("R")
+"""The command's return value."""
+
+
+class CommandOnCooldown(CommandError, typing.Generic[P, R]):
+ """Raised when a command is invoked while on cooldown."""
+
+ def __init__(
+ self,
+ message: str | None,
+ function: Callable[P, Awaitable[R]],
+ /,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ):
+ super().__init__(message, function, args, kwargs)
+ self._function = function
+ self._args = args
+ self._kwargs = kwargs
+
+ async def call_without_cooldown(self) -> R:
+ """
+ Run the command this cooldown blocked.
+
+ Returns:
+ The command's return value.
+ """
+ return await self._function(*self._args, **self._kwargs)
+
+
+@dataclass
+class _CooldownItem:
+ non_hashable_arguments: _ArgsList
+ timeout_timestamp: float
+
+
+@dataclass
+class _SeparatedArguments:
+ """Arguments separated into their hashable and non-hashable parts."""
+
+ hashable: _HashableArgsTuple
+ non_hashable: _ArgsList
+
+ @classmethod
+ def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self:
+ """Create a new instance from full call arguments."""
+ hashable = list[Hashable]()
+ non_hashable = list[object]()
+
+ for item in call_arguments:
+ try:
+ hash(item)
+ except TypeError:
+ non_hashable.append(item)
+ else:
+ hashable.append(item)
+
+ return cls(tuple(hashable), non_hashable)
+
+
+class _CommandCooldownManager:
+ """
+ Manage invocation cooldowns for a command through the arguments the command is called with.
+
+ Use `set_cooldown` to set a cooldown,
+ and `is_on_cooldown` to check for a cooldown for a channel with the given arguments.
+ A cooldown lasts for `cooldown_duration` seconds.
+ """
+
+ def __init__(self, *, cooldown_duration: float):
+ self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]()
+ self._cooldown_duration = cooldown_duration
+ self.cleanup_task = scheduling.create_task(
+ self._periodical_cleanup(random.uniform(0, 10)),
+ name="CooldownManager cleanup",
+ )
+ weakref.finalize(self, self.cleanup_task.cancel)
+
+ def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None:
+ """Set `call_arguments` arguments on cooldown in `channel`."""
+ timeout_timestamp = time.monotonic() + self._cooldown_duration
+ separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
+ cooldowns_list = self._cooldowns.setdefault(
+ (channel, separated_arguments.hashable),
+ [],
+ )
+
+ for item in cooldowns_list:
+ if item.non_hashable_arguments == separated_arguments.non_hashable:
+ item.timeout_timestamp = timeout_timestamp
+ return
+
+ cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp))
+
+ def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool:
+ """Check whether `call_arguments` is on cooldown in `channel`."""
+ current_time = time.monotonic()
+ separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
+ cooldowns_list = self._cooldowns.get(
+ (channel, separated_arguments.hashable),
+ [],
+ )
+
+ for item in cooldowns_list:
+ if item.non_hashable_arguments == separated_arguments.non_hashable:
+ return item.timeout_timestamp > current_time
+ return False
+
+ async def _periodical_cleanup(self, initial_delay: float) -> None:
+ """
+ Delete stale items every hour after waiting for `initial_delay`.
+
+ The `initial_delay` ensures cleanups are not running for every command at the same time.
+ A strong reference to self is only kept while cleanup is running.
+ """
+ weak_self = weakref.ref(self)
+ del self
+
+ await asyncio.sleep(initial_delay)
+ while True:
+ await asyncio.sleep(60 * 60)
+ weak_self()._delete_stale_items()
+
+ def _delete_stale_items(self) -> None:
+ """Remove expired items from internal collections."""
+ current_time = time.monotonic()
+
+ for key, cooldowns_list in self._cooldowns.copy().items():
+ filtered_cooldowns = [
+ cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time
+ ]
+
+ if not filtered_cooldowns:
+ del self._cooldowns[key]
+ else:
+ self._cooldowns[key] = filtered_cooldowns
+
+
+def _create_argument_tuple(*args: object, **kwargs: object) -> tuple[object, ...]:
+ return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items())
+
+
+def block_duplicate_invocations(
+ *,
+ cooldown_duration: float = 5,
+ send_notice: bool = False,
+ args_preprocessor: Callable[P, Iterable[object]] | None = None,
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+ """
+ Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds.
+
+ Args:
+ cooldown_duration: Length of the cooldown in seconds.
+ send_notice: If :obj:`True`, notify the user about the cooldown with a reply.
+ args_preprocessor: If specified, this function is called with the args and kwargs the function is called with,
+ its return value is then used to check for the cooldown instead of the raw arguments.
+
+ Returns:
+ A decorator that adds a wrapper which applies the cooldowns.
+
+ Warning:
+ The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown.
+ """
+
+ def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
+ mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration)
+
+ @command_wraps(func)
+ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ if args_preprocessor is not None:
+ all_args = args_preprocessor(*args, **kwargs)
+ else:
+ all_args = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command
+ ctx = typing.cast("Context[BotBase]", args[1])
+
+ if not isinstance(ctx.channel, discord.DMChannel):
+ if mgr.is_on_cooldown(ctx.channel, all_args):
+ if send_notice:
+ with suppress(discord.NotFound):
+ await ctx.reply("The command is on cooldown with the given arguments.")
+ raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs)
+ mgr.set_cooldown(ctx.channel, all_args)
+
+ return await func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py
new file mode 100644
index 00000000..d89163ec
--- /dev/null
+++ b/pydis_core/utils/function.py
@@ -0,0 +1,111 @@
+"""Utils for manipulating functions."""
+
+from __future__ import annotations
+
+import functools
+import types
+import typing
+from collections.abc import Callable, Sequence, Set
+
+__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"]
+
+
+if typing.TYPE_CHECKING:
+ _P = typing.ParamSpec("_P")
+ _R = typing.TypeVar("_R")
+
+
+class GlobalNameConflictError(Exception):
+ """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
+
+
+def update_wrapper_globals(
+ wrapper: Callable[_P, _R],
+ wrapped: Callable[_P, _R],
+ *,
+ ignored_conflict_names: Set[str] = frozenset(),
+) -> Callable[_P, _R]:
+ r"""
+ Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals.
+
+ For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function
+ to resolve their values. This breaks for decorators that replace the function because they have
+ their own globals.
+
+ .. warning::
+ This function captures the state of ``wrapped``\'s module's globals when it's called;
+ changes won't be reflected in the new function's globals.
+
+ Args:
+ wrapper: The function to wrap.
+ wrapped: The function to wrap with.
+ ignored_conflict_names: A set of names to ignore if a conflict between them is found.
+
+ Raises:
+ :exc:`GlobalNameConflictError`:
+ If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints,
+ and is not in ``ignored_conflict_names``.
+ """
+ wrapped = typing.cast(types.FunctionType, wrapped)
+ wrapper = typing.cast(types.FunctionType, wrapper)
+
+ annotation_global_names = (
+ ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str)
+ )
+ # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
+ shared_globals = (
+ set(wrapper.__code__.co_names)
+ & set(annotation_global_names)
+ & set(wrapped.__globals__)
+ & set(wrapper.__globals__)
+ - ignored_conflict_names
+ )
+ if shared_globals:
+ raise GlobalNameConflictError(
+ f"wrapper and the wrapped function share the following "
+ f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add "
+ f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional."
+ )
+
+ new_globals = wrapper.__globals__.copy()
+ new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names)
+ return types.FunctionType(
+ code=wrapper.__code__,
+ globals=new_globals,
+ name=wrapper.__name__,
+ argdefs=wrapper.__defaults__,
+ closure=wrapper.__closure__,
+ )
+
+
+def command_wraps(
+ wrapped: Callable[_P, _R],
+ assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
+ updated: Sequence[str] = functools.WRAPPER_UPDATES,
+ *,
+ ignored_conflict_names: Set[str] = frozenset(),
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
+ r"""
+ Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation.
+
+ See :func:`update_wrapper_globals` for more details on how the globals are updated.
+
+ Args:
+ wrapped: The function to wrap with.
+ assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``.
+ updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``.
+ ignored_conflict_names: A set of names to ignore if a conflict between them is found.
+
+ Returns:
+ A decorator that behaves like :func:`functools.wraps`,
+ with the wrapper replaced with the function :func:`update_wrapper_globals` returned.
+ """ # noqa: D200
+ def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]:
+ return functools.update_wrapper(
+ update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names),
+ wrapped,
+ assigned,
+ updated,
+ )
+
+ return decorator
diff --git a/pydis_core/utils/interactions.py b/pydis_core/utils/interactions.py
new file mode 100644
index 00000000..3e4acffe
--- /dev/null
+++ b/pydis_core/utils/interactions.py
@@ -0,0 +1,98 @@
+import contextlib
+from typing import Optional, Sequence
+
+from discord import ButtonStyle, Interaction, Message, NotFound, ui
+
+from pydis_core.utils.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class ViewWithUserAndRoleCheck(ui.View):
+ """
+ A view that allows the original invoker and moderators to interact with it.
+
+ Args:
+ allowed_users: A sequence of user's ids who are allowed to interact with the view.
+ allowed_roles: A sequence of role ids that are allowed to interact with the view.
+ timeout: Timeout in seconds from last interaction with the UI before no longer accepting input.
+ If ``None`` then there is no timeout.
+ message: The message to remove the view from on timeout. This can also be set with
+ ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated.
+ """
+
+ def __init__(
+ self,
+ *,
+ allowed_users: Sequence[int],
+ allowed_roles: Sequence[int],
+ timeout: Optional[float] = 180.0,
+ message: Optional[Message] = None
+ ) -> None:
+ super().__init__(timeout=timeout)
+ self.allowed_users = allowed_users
+ self.allowed_roles = allowed_roles
+ self.message = message
+
+ async def interaction_check(self, interaction: Interaction) -> bool:
+ """
+ Ensure the user clicking the button is the view invoker, or a moderator.
+
+ Args:
+ interaction: The interaction that occurred.
+ """
+ if interaction.user.id in self.allowed_users:
+ log.trace(
+ "Allowed interaction by %s (%d) on %d as they are an allowed user.",
+ interaction.user,
+ interaction.user.id,
+ interaction.message.id,
+ )
+ return True
+
+ if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])):
+ log.trace(
+ "Allowed interaction by %s (%d)on %d as they have an allowed role.",
+ interaction.user,
+ interaction.user.id,
+ interaction.message.id,
+ )
+ return True
+
+ await interaction.response.send_message("This is not your button to click!", ephemeral=True)
+ return False
+
+ async def on_timeout(self) -> None:
+ """Remove the view from ``self.message`` if set."""
+ if self.message:
+ with contextlib.suppress(NotFound):
+ # Cover the case where this message has already been deleted by external means
+ await self.message.edit(view=None)
+
+
+class DeleteMessageButton(ui.Button):
+ """
+ A button that can be added to a view to delete the message containing the view on click.
+
+ This button itself carries out no interaction checks, these should be done by the parent view.
+
+ See :obj:`pydis_core.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks.
+
+ Args:
+ style (:literal-url:`ButtonStyle <https://discordpy.readthedocs.io/en/latest/interactions/api.html#discord.ButtonStyle>`):
+ The style of the button, set to ``ButtonStyle.secondary`` if not specified.
+ label: The label of the button, set to "Delete" if not specified.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ *,
+ style: ButtonStyle = ButtonStyle.secondary,
+ label: str = "Delete",
+ **kwargs
+ ):
+ super().__init__(style=style, label=label, **kwargs)
+
+ async def callback(self, interaction: Interaction) -> None:
+ """Delete the original message on button click."""
+ await interaction.message.delete()
diff --git a/pydis_core/utils/logging.py b/pydis_core/utils/logging.py
new file mode 100644
index 00000000..7814f348
--- /dev/null
+++ b/pydis_core/utils/logging.py
@@ -0,0 +1,51 @@
+"""Common logging related functions."""
+
+import logging
+import typing
+
+if typing.TYPE_CHECKING:
+ LoggerClass = logging.Logger
+else:
+ LoggerClass = logging.getLoggerClass()
+
+TRACE_LEVEL = 5
+
+
+class CustomLogger(LoggerClass):
+ """Custom implementation of the :obj:`logging.Logger` class with an added :obj:`trace` method."""
+
+ def trace(self, msg: str, *args, **kwargs) -> None:
+ """
+ Log the given message with the severity ``"TRACE"``.
+
+ To pass exception information, use the keyword argument exc_info with a true value:
+
+ .. code-block:: py
+
+ logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
+
+ Args:
+ msg: The message to be logged.
+ args, kwargs: Passed to the base log function as is.
+ """
+ if self.isEnabledFor(TRACE_LEVEL):
+ self.log(TRACE_LEVEL, msg, *args, **kwargs)
+
+
+def get_logger(name: typing.Optional[str] = None) -> CustomLogger:
+ """
+ Utility to make mypy recognise that logger is of type :obj:`CustomLogger`.
+
+ Args:
+ name: The name given to the logger.
+
+ Returns:
+ An instance of the :obj:`CustomLogger` class.
+ """
+ return typing.cast(CustomLogger, logging.getLogger(name))
+
+
+# Setup trace level logging so that we can use it within pydis_core.
+logging.TRACE = TRACE_LEVEL
+logging.setLoggerClass(CustomLogger)
+logging.addLevelName(TRACE_LEVEL, "TRACE")
diff --git a/pydis_core/utils/members.py b/pydis_core/utils/members.py
new file mode 100644
index 00000000..b6eacc88
--- /dev/null
+++ b/pydis_core/utils/members.py
@@ -0,0 +1,57 @@
+"""Useful helper functions for interactin with :obj:`discord.Member` objects."""
+import typing
+from collections import abc
+
+import discord
+
+from pydis_core.utils import logging
+
+log = logging.get_logger(__name__)
+
+
+async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]:
+ """
+ Attempt to get a member from cache; on failure fetch from the API.
+
+ Returns:
+ The :obj:`discord.Member` or :obj:`None` to indicate the member could not be found.
+ """
+ if member := guild.get_member(member_id):
+ log.trace(f"{member} retrieved from cache.")
+ else:
+ try:
+ member = await guild.fetch_member(member_id)
+ except discord.errors.NotFound:
+ log.trace(f"Failed to fetch {member_id} from API.")
+ return None
+ log.trace(f"{member} fetched from API.")
+ return member
+
+
+async def handle_role_change(
+ member: discord.Member,
+ coro: typing.Callable[[discord.Role], abc.Coroutine],
+ role: discord.Role
+) -> None:
+ """
+ Await the given ``coro`` with ``role`` as the sole argument.
+
+ Handle errors that we expect to be raised from
+ :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`.
+
+ Args:
+ member: The member that is being modified for logging purposes.
+ coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`.
+ role: The role to be passed to ``coro``.
+ """
+ try:
+ await coro(role)
+ except discord.NotFound:
+ log.error(f"Failed to change role for {member} ({member.id}): member not found")
+ except discord.Forbidden:
+ log.error(
+ f"Forbidden to change role for {member} ({member.id}); "
+ f"possibly due to role hierarchy"
+ )
+ except discord.HTTPException as e:
+ log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}")
diff --git a/pydis_core/utils/regex.py b/pydis_core/utils/regex.py
new file mode 100644
index 00000000..de82a1ed
--- /dev/null
+++ b/pydis_core/utils/regex.py
@@ -0,0 +1,54 @@
+"""Common regular expressions."""
+
+import re
+
+DISCORD_INVITE = re.compile(
+ r"(https?://)?(www\.)?" # Optional http(s) and www.
+ r"(discord([.,]|dot)gg|" # Could be discord.gg/
+ r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/
+ r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/
+ r"discord([.,]|dot)me|" # or discord.me
+ r"discord([.,]|dot)li|" # or discord.li
+ r"discord([.,]|dot)io|" # or discord.io.
+ r"((?<!\w)([.,]|dot))gg" # or .gg/
+ r")(/|slash)" # / or 'slash'
+ r"(?P<invite>\S+)", # the invite code itself
+ flags=re.IGNORECASE
+)
+"""
+Regex for Discord server invites.
+
+.. warning::
+ This regex pattern will capture until a whitespace, if you are to use the 'invite' capture group in
+ any HTTP requests or similar. Please ensure you sanitise the output using something
+ such as :func:`urllib.parse.quote`.
+
+:meta hide-value:
+"""
+
+FORMATTED_CODE_REGEX = re.compile(
+ r"(?P<delim>(?P<block>```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block
+ r"(?(block)(?:(?P<lang>[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline)
+ r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
+ r"(?P<code>.*?)" # extract all code inside the markup
+ r"\s*" # any more whitespace before the end of the code markup
+ r"(?P=delim)", # match the exact same delimiter from the start again
+ flags=re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive
+)
+"""
+Regex for formatted code, using Discord's code blocks.
+
+:meta hide-value:
+"""
+
+RAW_CODE_REGEX = re.compile(
+ r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
+ r"(?P<code>.*?)" # extract all the rest as code
+ r"\s*$", # any trailing whitespace until the end of the string
+ flags=re.DOTALL # "." also matches newlines
+)
+"""
+Regex for raw code, *not* using Discord's code blocks.
+
+:meta hide-value:
+"""
diff --git a/pydis_core/utils/scheduling.py b/pydis_core/utils/scheduling.py
new file mode 100644
index 00000000..eced4a3d
--- /dev/null
+++ b/pydis_core/utils/scheduling.py
@@ -0,0 +1,252 @@
+"""Generic python scheduler."""
+
+import asyncio
+import contextlib
+import inspect
+import typing
+from collections import abc
+from datetime import datetime
+from functools import partial
+
+from pydis_core.utils import logging
+
+
+class Scheduler:
+ """
+ Schedule the execution of coroutines and keep track of them.
+
+ When instantiating a :obj:`Scheduler`, a name must be provided. This name is used to distinguish the
+ instance's log messages from other instances. Using the name of the class or module containing
+ the instance is suggested.
+
+ Coroutines can be scheduled immediately with :obj:`schedule` or in the future with :obj:`schedule_at`
+ or :obj:`schedule_later`. A unique ID is required to be given in order to keep track of the
+ resulting Tasks. Any scheduled task can be cancelled prematurely using :obj:`cancel` by providing
+ the same ID used to schedule it.
+
+ The ``in`` operator is supported for checking if a task with a given ID is currently scheduled.
+
+ Any exception raised in a scheduled task is logged when the task is done.
+ """
+
+ def __init__(self, name: str):
+ """
+ Initialize a new :obj:`Scheduler` instance.
+
+ Args:
+ name: The name of the :obj:`Scheduler`. Used in logging, and namespacing.
+ """
+ self.name = name
+
+ self._log = logging.get_logger(f"{__name__}.{name}")
+ self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {}
+
+ def __contains__(self, task_id: abc.Hashable) -> bool:
+ """
+ Return :obj:`True` if a task with the given ``task_id`` is currently scheduled.
+
+ Args:
+ task_id: The task to look for.
+
+ Returns:
+ :obj:`True` if the task was found.
+ """
+ return task_id in self._scheduled_tasks
+
+ def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:
+ """
+ Schedule the execution of a ``coroutine``.
+
+ If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self._log.trace(f"Scheduling task #{task_id}...")
+
+ msg = f"Cannot schedule an already started coroutine for #{task_id}"
+ assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg
+
+ if task_id in self._scheduled_tasks:
+ self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.")
+ coroutine.close()
+ return
+
+ task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}")
+ task.add_done_callback(partial(self._task_done_callback, task_id))
+
+ self._scheduled_tasks[task_id] = task
+ self._log.debug(f"Scheduled task #{task_id} {id(task)}.")
+
+ def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:
+ """
+ Schedule ``coroutine`` to be executed at the given ``time``.
+
+ If ``time`` is timezone aware, then use that timezone to calculate now() when subtracting.
+ If ``time`` is naïve, then use UTC.
+
+ If ``time`` is in the past, schedule ``coroutine`` immediately.
+
+ If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ time: The time to start the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow()
+ delay = (time - now_datetime).total_seconds()
+ if delay > 0:
+ coroutine = self._await_later(delay, task_id, coroutine)
+
+ self.schedule(task_id, coroutine)
+
+ def schedule_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: abc.Hashable,
+ coroutine: abc.Coroutine
+ ) -> None:
+ """
+ Schedule ``coroutine`` to be executed after ``delay`` seconds.
+
+ If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ delay: How long to wait before starting the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self.schedule(task_id, self._await_later(delay, task_id, coroutine))
+
+ def cancel(self, task_id: abc.Hashable) -> None:
+ """
+ Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist.
+
+ Args:
+ task_id: The task's unique ID.
+ """
+ self._log.trace(f"Cancelling task #{task_id}...")
+
+ try:
+ task = self._scheduled_tasks.pop(task_id)
+ except KeyError:
+ self._log.warning(f"Failed to unschedule {task_id} (no task found).")
+ else:
+ task.cancel()
+
+ self._log.debug(f"Unscheduled task #{task_id} {id(task)}.")
+
+ def cancel_all(self) -> None:
+ """Unschedule all known tasks."""
+ self._log.debug("Unscheduling all tasks")
+
+ for task_id in self._scheduled_tasks.copy():
+ self.cancel(task_id)
+
+ async def _await_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: abc.Hashable,
+ coroutine: abc.Coroutine
+ ) -> None:
+ """Await ``coroutine`` after ``delay`` seconds."""
+ try:
+ self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.")
+ await asyncio.sleep(delay)
+
+ # Use asyncio.shield to prevent the coroutine from cancelling itself.
+ self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.")
+ await asyncio.shield(coroutine)
+ finally:
+ # Close it to prevent unawaited coroutine warnings,
+ # which would happen if the task was cancelled during the sleep.
+ # Only close it if it's not been awaited yet. This check is important because the
+ # coroutine may cancel this task, which would also trigger the finally block.
+ state = inspect.getcoroutinestate(coroutine)
+ if state == "CORO_CREATED":
+ self._log.debug(f"Explicitly closing the coroutine for #{task_id}.")
+ coroutine.close()
+ else:
+ self._log.debug(f"Finally block reached for #{task_id}; {state=}")
+
+ def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None:
+ """
+ Delete the task and raise its exception if one exists.
+
+ If ``done_task`` and the task associated with ``task_id`` are different, then the latter
+ will not be deleted. In this case, a new task was likely rescheduled with the same ID.
+ """
+ self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.")
+
+ scheduled_task = self._scheduled_tasks.get(task_id)
+
+ if scheduled_task and done_task is scheduled_task:
+ # A task for the ID exists and is the same as the done task.
+ # Since this is the done callback, the task is already done so no need to cancel it.
+ self._log.trace(f"Deleting task #{task_id} {id(done_task)}.")
+ del self._scheduled_tasks[task_id]
+ elif scheduled_task:
+ # A new task was likely rescheduled with the same ID.
+ self._log.debug(
+ f"The scheduled task #{task_id} {id(scheduled_task)} "
+ f"and the done task {id(done_task)} differ."
+ )
+ elif not done_task.cancelled():
+ self._log.warning(
+ f"Task #{task_id} not found while handling task {id(done_task)}! "
+ f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)."
+ )
+
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = done_task.exception()
+ # Log the exception if one exists.
+ if exception:
+ self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
+
+
+TASK_RETURN = typing.TypeVar("TASK_RETURN")
+
+
+def create_task(
+ coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN],
+ *,
+ suppressed_exceptions: tuple[type[Exception], ...] = (),
+ event_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
+ **kwargs,
+) -> asyncio.Task[TASK_RETURN]:
+ """
+ Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task.
+
+ If the ``event_loop`` kwarg is provided, the task is created from that event loop,
+ otherwise the running loop is used.
+
+ Args:
+ coro: The function to call.
+ suppressed_exceptions: Exceptions to be handled by the task.
+ event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from.
+ kwargs: Passed to :py:func:`asyncio.create_task`.
+
+ Returns:
+ asyncio.Task: The wrapped task.
+ """
+ if event_loop is not None:
+ task = event_loop.create_task(coro, **kwargs)
+ else:
+ task = asyncio.create_task(coro, **kwargs)
+ task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
+ return task
+
+
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None:
+ """Retrieve and log the exception raised in ``task`` if one exists."""
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = task.exception()
+ # Log the exception if one exists.
+ if exception and not isinstance(exception, suppressed_exceptions):
+ log = logging.get_logger(__name__)
+ log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)