From 962968fecedca3bef33ba9524d87ffedf815f16d Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Sat, 5 Nov 2022 13:39:52 +0000 Subject: Rename package due to naming conflict --- .dockerignore | 2 +- .github/workflows/lint-test.yaml | 2 +- botcore/__init__.py | 15 -- botcore/_bot.py | 288 -------------------------------- botcore/async_stats.py | 57 ------- botcore/exts/__init__.py | 4 - botcore/site_api.py | 157 ----------------- botcore/utils/__init__.py | 50 ------ botcore/utils/_extensions.py | 57 ------- botcore/utils/_monkey_patches.py | 73 -------- botcore/utils/caching.py | 65 ------- botcore/utils/channel.py | 54 ------ botcore/utils/commands.py | 38 ----- botcore/utils/cooldown.py | 220 ------------------------ botcore/utils/function.py | 111 ------------ botcore/utils/interactions.py | 98 ----------- botcore/utils/logging.py | 51 ------ botcore/utils/members.py | 57 ------- botcore/utils/regex.py | 54 ------ botcore/utils/scheduling.py | 252 ---------------------------- dev/README.rst | 6 +- dev/bot/__init__.py | 4 +- dev/bot/__main__.py | 4 +- docker-compose.yaml | 2 +- docs/changelog.rst | 26 +-- docs/index.rst | 2 +- docs/utils.py | 12 +- pydis_core/__init__.py | 15 ++ pydis_core/_bot.py | 288 ++++++++++++++++++++++++++++++++ pydis_core/async_stats.py | 57 +++++++ pydis_core/exts/__init__.py | 4 + pydis_core/site_api.py | 157 +++++++++++++++++ pydis_core/utils/__init__.py | 50 ++++++ pydis_core/utils/_extensions.py | 57 +++++++ pydis_core/utils/_monkey_patches.py | 73 ++++++++ pydis_core/utils/caching.py | 65 +++++++ pydis_core/utils/channel.py | 54 ++++++ pydis_core/utils/commands.py | 38 +++++ pydis_core/utils/cooldown.py | 220 ++++++++++++++++++++++++ pydis_core/utils/function.py | 111 ++++++++++++ pydis_core/utils/interactions.py | 98 +++++++++++ pydis_core/utils/logging.py | 51 ++++++ pydis_core/utils/members.py | 57 +++++++ pydis_core/utils/regex.py | 54 ++++++ pydis_core/utils/scheduling.py | 252 ++++++++++++++++++++++++++++ pyproject.toml | 11 
+- tests/botcore/test_api.py | 69 -------- tests/botcore/utils/test_cooldown.py | 49 ------ tests/botcore/utils/test_regex.py | 65 ------- tests/pydis_core/test_api.py | 69 ++++++++ tests/pydis_core/utils/test_cooldown.py | 49 ++++++ tests/pydis_core/utils/test_regex.py | 65 +++++++ tox.ini | 2 +- 53 files changed, 1923 insertions(+), 1918 deletions(-) delete mode 100644 botcore/__init__.py delete mode 100644 botcore/_bot.py delete mode 100644 botcore/async_stats.py delete mode 100644 botcore/exts/__init__.py delete mode 100644 botcore/site_api.py delete mode 100644 botcore/utils/__init__.py delete mode 100644 botcore/utils/_extensions.py delete mode 100644 botcore/utils/_monkey_patches.py delete mode 100644 botcore/utils/caching.py delete mode 100644 botcore/utils/channel.py delete mode 100644 botcore/utils/commands.py delete mode 100644 botcore/utils/cooldown.py delete mode 100644 botcore/utils/function.py delete mode 100644 botcore/utils/interactions.py delete mode 100644 botcore/utils/logging.py delete mode 100644 botcore/utils/members.py delete mode 100644 botcore/utils/regex.py delete mode 100644 botcore/utils/scheduling.py create mode 100644 pydis_core/__init__.py create mode 100644 pydis_core/_bot.py create mode 100644 pydis_core/async_stats.py create mode 100644 pydis_core/exts/__init__.py create mode 100644 pydis_core/site_api.py create mode 100644 pydis_core/utils/__init__.py create mode 100644 pydis_core/utils/_extensions.py create mode 100644 pydis_core/utils/_monkey_patches.py create mode 100644 pydis_core/utils/caching.py create mode 100644 pydis_core/utils/channel.py create mode 100644 pydis_core/utils/commands.py create mode 100644 pydis_core/utils/cooldown.py create mode 100644 pydis_core/utils/function.py create mode 100644 pydis_core/utils/interactions.py create mode 100644 pydis_core/utils/logging.py create mode 100644 pydis_core/utils/members.py create mode 100644 pydis_core/utils/regex.py create mode 100644 pydis_core/utils/scheduling.py 
delete mode 100644 tests/botcore/test_api.py delete mode 100644 tests/botcore/utils/test_cooldown.py delete mode 100644 tests/botcore/utils/test_regex.py create mode 100644 tests/pydis_core/test_api.py create mode 100644 tests/pydis_core/utils/test_cooldown.py create mode 100644 tests/pydis_core/utils/test_regex.py diff --git a/.dockerignore b/.dockerignore index 9fb3df72..b36215c3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,6 @@ * -!botcore/ +!pydis_core/ !docs/ !tests/ diff --git a/.github/workflows/lint-test.yaml b/.github/workflows/lint-test.yaml index e9821677..dc83086b 100644 --- a/.github/workflows/lint-test.yaml +++ b/.github/workflows/lint-test.yaml @@ -41,7 +41,7 @@ jobs: --format='::error file=%(path)s,line=%(row)d,col=%(col)d::[flake8] %(code)s: %(text)s'" - name: Run tests and generate coverage report - run: python -m pytest -n auto --cov botcore -q + run: python -m pytest -n auto --cov pydis_core -q # Prepare the Pull Request Payload artifact. If this fails, we # we fail silently using the `continue-on-error` option. 
It's diff --git a/botcore/__init__.py b/botcore/__init__.py deleted file mode 100644 index f0c4e6bb..00000000 --- a/botcore/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Useful utilities and tools for Discord bot development.""" - -from botcore import async_stats, exts, site_api, utils -from botcore._bot import BotBase, StartupError - -__all__ = [ - async_stats, - BotBase, - exts, - utils, - site_api, - StartupError, -] - -__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/_bot.py b/botcore/_bot.py deleted file mode 100644 index bb25c0b5..00000000 --- a/botcore/_bot.py +++ /dev/null @@ -1,288 +0,0 @@ -import asyncio -import socket -import types -import warnings -from contextlib import suppress -from typing import Optional - -import aiohttp -import discord -from discord.ext import commands - -from botcore.async_stats import AsyncStatsClient -from botcore.site_api import APIClient -from botcore.utils import scheduling -from botcore.utils._extensions import walk_extensions -from botcore.utils.logging import get_logger - -try: - from async_rediscache import RedisSession - REDIS_AVAILABLE = True -except ImportError: - RedisSession = None - REDIS_AVAILABLE = False - -log = get_logger() - - -class StartupError(Exception): - """Exception class for startup errors.""" - - def __init__(self, base: Exception): - super().__init__() - self.exception = base - - -class BotBase(commands.Bot): - """A sub-class that implements many common features that Python Discord bots use.""" - - def __init__( - self, - *args, - guild_id: int, - allowed_roles: list, - http_session: aiohttp.ClientSession, - redis_session: Optional[RedisSession] = None, - api_client: Optional[APIClient] = None, - statsd_url: Optional[str] = None, - **kwargs, - ): - """ - Initialise the base bot instance. - - Args: - guild_id: The ID of the guild used for :func:`wait_until_guild_available`. - allowed_roles: A list of role IDs that the bot is allowed to mention. 
- http_session (aiohttp.ClientSession): The session to use for the bot. - redis_session: The `async_rediscache.RedisSession`_ to use for the bot. - api_client: The :obj:`botcore.site_api.APIClient` instance to use for the bot. - statsd_url: The URL of the statsd server to use for the bot. If not given, - a dummy statsd client will be created. - - .. _async_rediscache.RedisSession: https://github.com/SebastiaanZ/async-rediscache#creating-a-redissession - """ - super().__init__( - *args, - allowed_roles=allowed_roles, - **kwargs, - ) - - self.guild_id = guild_id - self.http_session = http_session - self.api_client = api_client - self.statsd_url = statsd_url - - if redis_session and not REDIS_AVAILABLE: - warnings.warn("redis_session kwarg passed, but async-rediscache not installed!") - elif redis_session: - self.redis_session = redis_session - - self._resolver: Optional[aiohttp.AsyncResolver] = None - self._connector: Optional[aiohttp.TCPConnector] = None - - self._statsd_timerhandle: Optional[asyncio.TimerHandle] = None - self._guild_available: Optional[asyncio.Event] = None - - self.stats: Optional[AsyncStatsClient] = None - - self.all_extensions: Optional[frozenset[str]] = None - - def _connect_statsd( - self, - statsd_url: str, - loop: asyncio.AbstractEventLoop, - retry_after: int = 2, - attempt: int = 1 - ) -> None: - """Callback used to retry a connection to statsd if it should fail.""" - if attempt >= 8: - log.error( - "Reached 8 attempts trying to reconnect AsyncStatsClient to %s. " - "Aborting and leaving the dummy statsd client in place.", - statsd_url, - ) - return - - try: - self.stats = AsyncStatsClient(loop, statsd_url, 8125, prefix="bot") - except socket.gaierror: - log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})") - # Use a fallback strategy for retrying, up to 8 times. 
- self._statsd_timerhandle = loop.call_later( - retry_after, - self._connect_statsd, - statsd_url, - retry_after * 2, - attempt + 1 - ) - - async def load_extensions(self, module: types.ModuleType) -> None: - """ - Load all the extensions within the given module and save them to ``self.all_extensions``. - - This should be ran in a task on the event loop to avoid deadlocks caused by ``wait_for`` calls. - """ - await self.wait_until_guild_available() - self.all_extensions = walk_extensions(module) - - for extension in self.all_extensions: - scheduling.create_task(self.load_extension(extension)) - - def _add_root_aliases(self, command: commands.Command) -> None: - """Recursively add root aliases for ``command`` and any of its subcommands.""" - if isinstance(command, commands.Group): - for subcommand in command.commands: - self._add_root_aliases(subcommand) - - for alias in getattr(command, "root_aliases", ()): - if alias in self.all_commands: - raise commands.CommandRegistrationError(alias, alias_conflict=True) - - self.all_commands[alias] = command - - def _remove_root_aliases(self, command: commands.Command) -> None: - """Recursively remove root aliases for ``command`` and any of its subcommands.""" - if isinstance(command, commands.Group): - for subcommand in command.commands: - self._remove_root_aliases(subcommand) - - for alias in getattr(command, "root_aliases", ()): - self.all_commands.pop(alias, None) - - async def add_cog(self, cog: commands.Cog) -> None: - """Add the given ``cog`` to the bot and log the operation.""" - await super().add_cog(cog) - log.info(f"Cog loaded: {cog.qualified_name}") - - def add_command(self, command: commands.Command) -> None: - """Add ``command`` as normal and then add its root aliases to the bot.""" - super().add_command(command) - self._add_root_aliases(command) - - def remove_command(self, name: str) -> Optional[commands.Command]: - """ - Remove a command/alias as normal and then remove its root aliases from the bot. 
- - Individual root aliases cannot be removed by this function. - To remove them, either remove the entire command or manually edit `bot.all_commands`. - """ - command = super().remove_command(name) - if command is None: - # Even if it's a root alias, there's no way to get the Bot instance to remove the alias. - return None - - self._remove_root_aliases(command) - return command - - def clear(self) -> None: - """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one.""" - raise NotImplementedError("Re-using a Bot object after closing it is not supported.") - - async def on_guild_unavailable(self, guild: discord.Guild) -> None: - """Clear the internal guild available event when self.guild_id becomes unavailable.""" - if guild.id != self.guild_id: - return - - self._guild_available.clear() - - async def on_guild_available(self, guild: discord.Guild) -> None: - """ - Set the internal guild available event when self.guild_id becomes available. - - If the cache appears to still be empty (no members, no channels, or no roles), the event - will not be set and `guild_available_but_cache_empty` event will be emitted. - """ - if guild.id != self.guild_id: - return - - if not guild.roles or not guild.members or not guild.channels: - msg = "Guild available event was dispatched but the cache appears to still be empty!" - await self.log_to_dev_log(msg) - return - - self._guild_available.set() - - async def log_to_dev_log(self, message: str) -> None: - """Log the given message to #dev-log.""" - ... - - async def wait_until_guild_available(self) -> None: - """ - Wait until the guild that matches the ``guild_id`` given at init is available (and the cache is ready). - - The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE - gateway event before giving up and thus not populating the cache for unavailable guilds. 
- """ - await self._guild_available.wait() - - async def setup_hook(self) -> None: - """ - An async init to startup generic services. - - Connects to statsd, and calls - :func:`AsyncStatsClient.create_socket ` - and :func:`ping_services`. - """ - loop = asyncio.get_running_loop() - - self._guild_available = asyncio.Event() - - self._resolver = aiohttp.AsyncResolver() - self._connector = aiohttp.TCPConnector( - resolver=self._resolver, - family=socket.AF_INET, - ) - self.http.connector = self._connector - - if getattr(self, "redis_session", False) and not self.redis_session.valid: - # If the RedisSession was somehow closed, we try to reconnect it - # here. Normally, this shouldn't happen. - await self.redis_session.connect(ping=True) - - # Create dummy stats client first, in case `statsd_url` is unreachable or None - self.stats = AsyncStatsClient(loop, "127.0.0.1") - if self.statsd_url: - self._connect_statsd(self.statsd_url, loop) - - await self.stats.create_socket() - - try: - await self.ping_services() - except Exception as e: - raise StartupError(e) - - async def ping_services(self) -> None: - """Ping all required services on setup to ensure they are up before starting.""" - ... - - async def close(self) -> None: - """Close the Discord connection, and the aiohttp session, connector, statsd client, and resolver.""" - # Done before super().close() to allow tasks finish before the HTTP session closes. 
- for ext in list(self.extensions): - with suppress(Exception): - await self.unload_extension(ext) - - for cog in list(self.cogs): - with suppress(Exception): - await self.remove_cog(cog) - - # Now actually do full close of bot - await super().close() - - if self.api_client: - await self.api_client.close() - - if self.http_session: - await self.http_session.close() - - if self._connector: - await self._connector.close() - - if self._resolver: - await self._resolver.close() - - if getattr(self.stats, "_transport", False): - self.stats._transport.close() - - if self._statsd_timerhandle: - self._statsd_timerhandle.cancel() diff --git a/botcore/async_stats.py b/botcore/async_stats.py deleted file mode 100644 index fef5b2d6..00000000 --- a/botcore/async_stats.py +++ /dev/null @@ -1,57 +0,0 @@ -"""An async transport method for statsd communication.""" - -import asyncio -import socket -from typing import Optional - -from statsd.client.base import StatsClientBase - -from botcore.utils import scheduling - - -class AsyncStatsClient(StatsClientBase): - """An async implementation of :obj:`statsd.client.base.StatsClientBase` that supports async stat communication.""" - - def __init__( - self, - loop: asyncio.AbstractEventLoop, - host: str = 'localhost', - port: int = 8125, - prefix: str = None - ): - """ - Create a new :obj:`AsyncStatsClient`. - - Args: - loop (asyncio.AbstractEventLoop): The event loop to use when creating the - :obj:`asyncio.loop.create_datagram_endpoint`. - host: The host to connect to. - port: The port to connect to. - prefix: The prefix to use for all stats. 
- """ - _, _, _, _, addr = socket.getaddrinfo( - host, port, socket.AF_INET, socket.SOCK_DGRAM - )[0] - self._addr = addr - self._prefix = prefix - self._loop = loop - self._transport: Optional[asyncio.DatagramTransport] = None - - async def create_socket(self) -> None: - """Use :obj:`asyncio.loop.create_datagram_endpoint` from the loop given on init to create a socket.""" - self._transport, _ = await self._loop.create_datagram_endpoint( - asyncio.DatagramProtocol, - family=socket.AF_INET, - remote_addr=self._addr - ) - - def _send(self, data: str) -> None: - """Start an async task to send data to statsd.""" - scheduling.create_task(self._async_send(data), event_loop=self._loop) - - async def _async_send(self, data: str) -> None: - """Send data to the statsd server using the async transport.""" - self._transport.sendto(data.encode('ascii'), self._addr) - - -__all__ = ['AsyncStatsClient'] diff --git a/botcore/exts/__init__.py b/botcore/exts/__init__.py deleted file mode 100644 index afd56166..00000000 --- a/botcore/exts/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Reusable Discord cogs.""" -__all__ = [] - -__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/site_api.py b/botcore/site_api.py deleted file mode 100644 index 44309f9d..00000000 --- a/botcore/site_api.py +++ /dev/null @@ -1,157 +0,0 @@ -"""An API wrapper around the Site API.""" - -import asyncio -from typing import Optional -from urllib.parse import quote as quote_url - -import aiohttp - -from botcore.utils.logging import get_logger - -log = get_logger(__name__) - - -class ResponseCodeError(ValueError): - """Raised in :meth:`APIClient.request` when a non-OK HTTP response is received.""" - - def __init__( - self, - response: aiohttp.ClientResponse, - response_json: Optional[dict] = None, - response_text: Optional[str] = None - ): - """ - Initialize a new :obj:`ResponseCodeError` instance. 
- - Args: - response (:obj:`aiohttp.ClientResponse`): The response object from the request. - response_json: The JSON response returned from the request, if any. - response_text: The text of the request, if any. - """ - self.status = response.status - self.response_json = response_json or {} - self.response_text = response_text - self.response = response - - def __str__(self): - """Return a string representation of the error.""" - response = self.response_json or self.response_text - return f"Status: {self.status} Response: {response}" - - -class APIClient: - """A wrapper for the Django Site API.""" - - session: Optional[aiohttp.ClientSession] = None - loop: asyncio.AbstractEventLoop = None - - def __init__(self, site_api_url: str, site_api_token: str, **session_kwargs): - """ - Initialize a new :obj:`APIClient` instance. - - Args: - site_api_url: The URL of the site API. - site_api_token: The token to use for authentication. - session_kwargs: Keyword arguments to pass to the :obj:`aiohttp.ClientSession` constructor. - """ - self.site_api_url = site_api_url - - auth_headers = { - 'Authorization': f"Token {site_api_token}" - } - - if 'headers' in session_kwargs: - session_kwargs['headers'].update(auth_headers) - else: - session_kwargs['headers'] = auth_headers - - # aiohttp will complain if APIClient gets instantiated outside a coroutine. Thankfully, we - # don't and shouldn't need to do that, so we can avoid scheduling a task to create it. - self.session = aiohttp.ClientSession(**session_kwargs) - - def _url_for(self, endpoint: str) -> str: - return f"{self.site_api_url}/{quote_url(endpoint)}" - - async def close(self) -> None: - """Close the aiohttp session.""" - await self.session.close() - - @staticmethod - async def maybe_raise_for_status(response: aiohttp.ClientResponse, should_raise: bool) -> None: - """ - Raise :exc:`ResponseCodeError` for non-OK response if an exception should be raised. 
- - Args: - response (:obj:`aiohttp.ClientResponse`): The response to check. - should_raise: Whether or not to raise an exception. - - Raises: - :exc:`ResponseCodeError`: - If the response is not OK and ``should_raise`` is True. - """ - if should_raise and response.status >= 400: - try: - response_json = await response.json() - raise ResponseCodeError(response=response, response_json=response_json) - except aiohttp.ContentTypeError: - response_text = await response.text() - raise ResponseCodeError(response=response, response_text=response_text) - - async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """ - Send an HTTP request to the site API and return the JSON response. - - Args: - method: The HTTP method to use. - endpoint: The endpoint to send the request to. - raise_for_status: Whether or not to raise an exception if the response is not OK. - **kwargs: Any extra keyword arguments to pass to :func:`aiohttp.request`. - - Returns: - The JSON response the API returns. - - Raises: - :exc:`ResponseCodeError`: - If the response is not OK and ``raise_for_status`` is True. 
- """ - async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Equivalent to :meth:`APIClient.request` with GET passed as the method.""" - return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Equivalent to :meth:`APIClient.request` with PATCH passed as the method.""" - return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Equivalent to :meth:`APIClient.request` with POST passed as the method.""" - return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: - """Equivalent to :meth:`APIClient.request` with PUT passed as the method.""" - return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) - - async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]: - """ - Send a DELETE request to the site API and return the JSON response. - - Args: - endpoint: The endpoint to send the request to. - raise_for_status: Whether or not to raise an exception if the response is not OK. - **kwargs: Any extra keyword arguments to pass to :func:`aiohttp.request`. - - Returns: - The JSON response the API returns, or None if the response is 204 No Content. 
- """ - async with self.session.delete(self._url_for(endpoint), **kwargs) as resp: - if resp.status == 204: - return None - - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - -__all__ = ['APIClient', 'ResponseCodeError'] diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py deleted file mode 100644 index 09aaa45f..00000000 --- a/botcore/utils/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Useful utilities and tools for Discord bot development.""" - -from botcore.utils import ( - _monkey_patches, - caching, - channel, - commands, - cooldown, - function, - interactions, - logging, - members, - regex, - scheduling, -) -from botcore.utils._extensions import unqualify - - -def apply_monkey_patches() -> None: - """ - Applies all common monkey patches for our bots. - - Patches :obj:`discord.ext.commands.Command` and :obj:`discord.ext.commands.Group` to support root aliases. - A ``root_aliases`` keyword argument is added to these two objects, which is a sequence of alias names - that will act as top-level groups rather than being aliases of the command's group. - - It's stored as an attribute also named ``root_aliases`` - - Patches discord's internal ``send_typing`` method so that it ignores 403 errors from Discord. - When under heavy load Discord has added a CloudFlare worker to this route, which causes 403 errors to be thrown. 
- """ - _monkey_patches._apply_monkey_patches() - - -__all__ = [ - apply_monkey_patches, - caching, - channel, - commands, - cooldown, - function, - interactions, - logging, - members, - regex, - scheduling, - unqualify, -] - -__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/utils/_extensions.py b/botcore/utils/_extensions.py deleted file mode 100644 index 536a0715..00000000 --- a/botcore/utils/_extensions.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Utilities for loading Discord extensions.""" - -import importlib -import inspect -import pkgutil -import types -from typing import NoReturn - - -def unqualify(name: str) -> str: - """ - Return an unqualified name given a qualified module/package ``name``. - - Args: - name: The module name to unqualify. - - Returns: - The unqualified module name. - """ - return name.rsplit(".", maxsplit=1)[-1] - - -def ignore_module(module: pkgutil.ModuleInfo) -> bool: - """Return whether the module with name `name` should be ignored.""" - return any(name.startswith("_") for name in module.name.split(".")) - - -def walk_extensions(module: types.ModuleType) -> frozenset[str]: - """ - Return all extension names from the given module. - - Args: - module (types.ModuleType): The module to look for extensions in. - - Returns: - A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. - """ - - def on_error(name: str) -> NoReturn: - raise ImportError(name=name) # pragma: no cover - - modules = set() - - for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): - if ignore_module(module_info): - # Ignore modules/packages that have a name starting with an underscore anywhere in their trees. - continue - - if module_info.ispkg: - imported = importlib.import_module(module_info.name) - if not inspect.isfunction(getattr(imported, "setup", None)): - # If it lacks a setup function, it's not an extension. 
- continue - - modules.add(module_info.name) - - return frozenset(modules) diff --git a/botcore/utils/_monkey_patches.py b/botcore/utils/_monkey_patches.py deleted file mode 100644 index c2f8aa10..00000000 --- a/botcore/utils/_monkey_patches.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Contains all common monkey patches, used to alter discord to fit our needs.""" - -import logging -import typing -from datetime import datetime, timedelta -from functools import partial, partialmethod - -from discord import Forbidden, http -from discord.ext import commands - -log = logging.getLogger(__name__) - - -class _Command(commands.Command): - """ - A :obj:`discord.ext.commands.Command` subclass which supports root aliases. - - A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as - top-level commands rather than being aliases of the command's group. It's stored as an attribute - also named ``root_aliases``. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.root_aliases = kwargs.get("root_aliases", []) - - if not isinstance(self.root_aliases, (list, tuple)): - raise TypeError("Root aliases of a command must be a list or a tuple of strings.") - - -class _Group(commands.Group, _Command): - """ - A :obj:`discord.ext.commands.Group` subclass which supports root aliases. - - A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as - top-level groups rather than being aliases of the command's group. It's stored as an attribute - also named ``root_aliases``. - """ - - -def _patch_typing() -> None: - """ - Sometimes Discord turns off typing events by throwing 403s. - - Handle those issues by patching discord's internal ``send_typing`` method so it ignores 403s in general. - """ - log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. 
Stay safe!") - - original = http.HTTPClient.send_typing - last_403: typing.Optional[datetime] = None - - async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None: - nonlocal last_403 - if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5): - log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") - return - try: - await original(self, channel_id) - except Forbidden: - last_403 = datetime.utcnow() - log.warning("Got a 403 from typing event!") - - http.HTTPClient.send_typing = honeybadger_type - - -def _apply_monkey_patches() -> None: - """This is surfaced directly in botcore.utils.apply_monkey_patches().""" - commands.command = partial(commands.command, cls=_Command) - commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=_Command) - - commands.group = partial(commands.group, cls=_Group) - commands.GroupMixin.group = partialmethod(commands.GroupMixin.group, cls=_Group) - _patch_typing() diff --git a/botcore/utils/caching.py b/botcore/utils/caching.py deleted file mode 100644 index ac34bb9b..00000000 --- a/botcore/utils/caching.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Utilities related to custom caches.""" - -import functools -import typing -from collections import OrderedDict - - -class AsyncCache: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - - def __init__(self, max_size: int = 128): - """ - Initialise a new :obj:`AsyncCache` instance. - - Args: - max_size: How many items to store in the cache. - """ - self._cache = OrderedDict() - self._max_size = max_size - - def __call__(self, arg_offset: int = 0) -> typing.Callable: - """ - Decorator for async cache. - - Args: - arg_offset: The offset for the position of the key argument. - - Returns: - A decorator to wrap the target function. 
- """ - - def decorator(function: typing.Callable) -> typing.Callable: - """ - Define the async cache decorator. - - Args: - function: The function to wrap. - - Returns: - The wrapped function. - """ - - @functools.wraps(function) - async def wrapper(*args) -> typing.Any: - """Decorator wrapper for the caching logic.""" - key = args[arg_offset:] - - if key not in self._cache: - if len(self._cache) > self._max_size: - self._cache.popitem(last=False) - - self._cache[key] = await function(*args) - return self._cache[key] - return wrapper - return decorator - - def clear(self) -> None: - """Clear cache instance.""" - self._cache.clear() diff --git a/botcore/utils/channel.py b/botcore/utils/channel.py deleted file mode 100644 index c09d53bf..00000000 --- a/botcore/utils/channel.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Useful helper functions for interacting with various discord channel objects.""" - -import discord -from discord.ext.commands import Bot - -from botcore.utils import logging - -log = logging.get_logger(__name__) - - -def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: - """ - Return whether the given ``channel`` in the the category with the id ``category_id``. - - Args: - channel: The channel to check. - category_id: The category to check for. - - Returns: - A bool depending on whether the channel is in the category. - """ - return getattr(channel, "category_id", None) == category_id - - -async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel: - """ - Attempt to get or fetch the given ``channel_id`` from the bots cache, and return it. - - Args: - bot: The :obj:`discord.ext.commands.Bot` instance to use for getting/fetching. - channel_id: The channel to get/fetch. - - Raises: - :exc:`discord.InvalidData` - An unknown channel type was received from Discord. - :exc:`discord.HTTPException` - Retrieving the channel failed. - :exc:`discord.NotFound` - Invalid Channel ID. 
- :exc:`discord.Forbidden` - You do not have permission to fetch this channel. - - Returns: - The channel from the ID. - """ - log.trace(f"Getting the channel {channel_id}.") - - channel = bot.get_channel(channel_id) - if not channel: - log.debug(f"Channel {channel_id} is not in cache; fetching from API.") - channel = await bot.fetch_channel(channel_id) - - log.trace(f"Channel #{channel} ({channel_id}) retrieved.") - return channel diff --git a/botcore/utils/commands.py b/botcore/utils/commands.py deleted file mode 100644 index 7afd8137..00000000 --- a/botcore/utils/commands.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -from discord import Message -from discord.ext.commands import BadArgument, Context, clean_content - - -async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str: - """ - Cleans a text argument or replied message's content. - - Args: - ctx: The command's context - text: The provided text argument of the command (if given) - - Raises: - :exc:`discord.ext.commands.BadArgument` - `text` wasn't provided and there's no reply message / reply message content. - - Returns: - The cleaned version of `text`, if given, else replied message. - """ - clean_content_converter = clean_content(fix_channel_mentions=True) - - if text: - return await clean_content_converter.convert(ctx, text) - - if ( - (replied_message := getattr(ctx.message.reference, "resolved", None)) # message has a cached reference - and isinstance(replied_message, Message) # referenced message hasn't been deleted - ): - if not (content := ctx.message.reference.resolved.content): - # The referenced message doesn't have a content (e.g. embed/image), so raise error - raise BadArgument("The referenced message doesn't have a text content.") - - return await clean_content_converter.convert(ctx, content) - - # No text provided, and either no message was referenced or we can't access the content - raise BadArgument("Couldn't find text to clean. 
Provide a string or reply to a message to use its content.") diff --git a/botcore/utils/cooldown.py b/botcore/utils/cooldown.py deleted file mode 100644 index 015734d2..00000000 --- a/botcore/utils/cooldown.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Helpers for setting a cooldown on commands.""" - -from __future__ import annotations - -import asyncio -import random -import time -import typing -import weakref -from collections.abc import Awaitable, Callable, Hashable, Iterable -from contextlib import suppress -from dataclasses import dataclass - -import discord -from discord.ext.commands import CommandError, Context - -from botcore.utils import scheduling -from botcore.utils.function import command_wraps - -__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"] - -_KEYWORD_SEP_SENTINEL = object() - -_ArgsList = list[object] -_HashableArgsTuple = tuple[Hashable, ...] - -if typing.TYPE_CHECKING: - import typing_extensions - from botcore import BotBase - -P = typing.ParamSpec("P") -"""The command's signature.""" -R = typing.TypeVar("R") -"""The command's return value.""" - - -class CommandOnCooldown(CommandError, typing.Generic[P, R]): - """Raised when a command is invoked while on cooldown.""" - - def __init__( - self, - message: str | None, - function: Callable[P, Awaitable[R]], - /, - *args: P.args, - **kwargs: P.kwargs, - ): - super().__init__(message, function, args, kwargs) - self._function = function - self._args = args - self._kwargs = kwargs - - async def call_without_cooldown(self) -> R: - """ - Run the command this cooldown blocked. - - Returns: - The command's return value. 
- """ - return await self._function(*self._args, **self._kwargs) - - -@dataclass -class _CooldownItem: - non_hashable_arguments: _ArgsList - timeout_timestamp: float - - -@dataclass -class _SeparatedArguments: - """Arguments separated into their hashable and non-hashable parts.""" - - hashable: _HashableArgsTuple - non_hashable: _ArgsList - - @classmethod - def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self: - """Create a new instance from full call arguments.""" - hashable = list[Hashable]() - non_hashable = list[object]() - - for item in call_arguments: - try: - hash(item) - except TypeError: - non_hashable.append(item) - else: - hashable.append(item) - - return cls(tuple(hashable), non_hashable) - - -class _CommandCooldownManager: - """ - Manage invocation cooldowns for a command through the arguments the command is called with. - - Use `set_cooldown` to set a cooldown, - and `is_on_cooldown` to check for a cooldown for a channel with the given arguments. - A cooldown lasts for `cooldown_duration` seconds. 
- """ - - def __init__(self, *, cooldown_duration: float): - self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]() - self._cooldown_duration = cooldown_duration - self.cleanup_task = scheduling.create_task( - self._periodical_cleanup(random.uniform(0, 10)), - name="CooldownManager cleanup", - ) - weakref.finalize(self, self.cleanup_task.cancel) - - def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None: - """Set `call_arguments` arguments on cooldown in `channel`.""" - timeout_timestamp = time.monotonic() + self._cooldown_duration - separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) - cooldowns_list = self._cooldowns.setdefault( - (channel, separated_arguments.hashable), - [], - ) - - for item in cooldowns_list: - if item.non_hashable_arguments == separated_arguments.non_hashable: - item.timeout_timestamp = timeout_timestamp - return - - cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp)) - - def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool: - """Check whether `call_arguments` is on cooldown in `channel`.""" - current_time = time.monotonic() - separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) - cooldowns_list = self._cooldowns.get( - (channel, separated_arguments.hashable), - [], - ) - - for item in cooldowns_list: - if item.non_hashable_arguments == separated_arguments.non_hashable: - return item.timeout_timestamp > current_time - return False - - async def _periodical_cleanup(self, initial_delay: float) -> None: - """ - Delete stale items every hour after waiting for `initial_delay`. - - The `initial_delay` ensures cleanups are not running for every command at the same time. - A strong reference to self is only kept while cleanup is running. 
- """ - weak_self = weakref.ref(self) - del self - - await asyncio.sleep(initial_delay) - while True: - await asyncio.sleep(60 * 60) - weak_self()._delete_stale_items() - - def _delete_stale_items(self) -> None: - """Remove expired items from internal collections.""" - current_time = time.monotonic() - - for key, cooldowns_list in self._cooldowns.copy().items(): - filtered_cooldowns = [ - cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time - ] - - if not filtered_cooldowns: - del self._cooldowns[key] - else: - self._cooldowns[key] = filtered_cooldowns - - -def _create_argument_tuple(*args: object, **kwargs: object) -> tuple[object, ...]: - return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items()) - - -def block_duplicate_invocations( - *, - cooldown_duration: float = 5, - send_notice: bool = False, - args_preprocessor: Callable[P, Iterable[object]] | None = None, -) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: - """ - Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds. - - Args: - cooldown_duration: Length of the cooldown in seconds. - send_notice: If :obj:`True`, notify the user about the cooldown with a reply. - args_preprocessor: If specified, this function is called with the args and kwargs the function is called with, - its return value is then used to check for the cooldown instead of the raw arguments. - - Returns: - A decorator that adds a wrapper which applies the cooldowns. - - Warning: - The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown. 
- """ - - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration) - - @command_wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - if args_preprocessor is not None: - all_args = args_preprocessor(*args, **kwargs) - else: - all_args = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command - ctx = typing.cast("Context[BotBase]", args[1]) - - if not isinstance(ctx.channel, discord.DMChannel): - if mgr.is_on_cooldown(ctx.channel, all_args): - if send_notice: - with suppress(discord.NotFound): - await ctx.reply("The command is on cooldown with the given arguments.") - raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs) - mgr.set_cooldown(ctx.channel, all_args) - - return await func(*args, **kwargs) - - return wrapper - - return decorator diff --git a/botcore/utils/function.py b/botcore/utils/function.py deleted file mode 100644 index d89163ec..00000000 --- a/botcore/utils/function.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Utils for manipulating functions.""" - -from __future__ import annotations - -import functools -import types -import typing -from collections.abc import Callable, Sequence, Set - -__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] - - -if typing.TYPE_CHECKING: - _P = typing.ParamSpec("_P") - _R = typing.TypeVar("_R") - - -class GlobalNameConflictError(Exception): - """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" - - -def update_wrapper_globals( - wrapper: Callable[_P, _R], - wrapped: Callable[_P, _R], - *, - ignored_conflict_names: Set[str] = frozenset(), -) -> Callable[_P, _R]: - r""" - Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals. - - For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function - to resolve their values. 
This breaks for decorators that replace the function because they have - their own globals. - - .. warning:: - This function captures the state of ``wrapped``\'s module's globals when it's called; - changes won't be reflected in the new function's globals. - - Args: - wrapper: The function to wrap. - wrapped: The function to wrap with. - ignored_conflict_names: A set of names to ignore if a conflict between them is found. - - Raises: - :exc:`GlobalNameConflictError`: - If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints, - and is not in ``ignored_conflict_names``. - """ - wrapped = typing.cast(types.FunctionType, wrapped) - wrapper = typing.cast(types.FunctionType, wrapper) - - annotation_global_names = ( - ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) - ) - # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. - shared_globals = ( - set(wrapper.__code__.co_names) - & set(annotation_global_names) - & set(wrapped.__globals__) - & set(wrapper.__globals__) - - ignored_conflict_names - ) - if shared_globals: - raise GlobalNameConflictError( - f"wrapper and the wrapped function share the following " - f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " - f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." 
- ) - - new_globals = wrapper.__globals__.copy() - new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) - return types.FunctionType( - code=wrapper.__code__, - globals=new_globals, - name=wrapper.__name__, - argdefs=wrapper.__defaults__, - closure=wrapper.__closure__, - ) - - -def command_wraps( - wrapped: Callable[_P, _R], - assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS, - updated: Sequence[str] = functools.WRAPPER_UPDATES, - *, - ignored_conflict_names: Set[str] = frozenset(), -) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: - r""" - Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation. - - See :func:`update_wrapper_globals` for more details on how the globals are updated. - - Args: - wrapped: The function to wrap with. - assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``. - updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``. - ignored_conflict_names: A set of names to ignore if a conflict between them is found. - - Returns: - A decorator that behaves like :func:`functools.wraps`, - with the wrapper replaced with the function :func:`update_wrapper_globals` returned. 
- """ # noqa: D200 - def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]: - return functools.update_wrapper( - update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), - wrapped, - assigned, - updated, - ) - - return decorator diff --git a/botcore/utils/interactions.py b/botcore/utils/interactions.py deleted file mode 100644 index 26bd92f2..00000000 --- a/botcore/utils/interactions.py +++ /dev/null @@ -1,98 +0,0 @@ -import contextlib -from typing import Optional, Sequence - -from discord import ButtonStyle, Interaction, Message, NotFound, ui - -from botcore.utils.logging import get_logger - -log = get_logger(__name__) - - -class ViewWithUserAndRoleCheck(ui.View): - """ - A view that allows the original invoker and moderators to interact with it. - - Args: - allowed_users: A sequence of user's ids who are allowed to interact with the view. - allowed_roles: A sequence of role ids that are allowed to interact with the view. - timeout: Timeout in seconds from last interaction with the UI before no longer accepting input. - If ``None`` then there is no timeout. - message: The message to remove the view from on timeout. This can also be set with - ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated. - """ - - def __init__( - self, - *, - allowed_users: Sequence[int], - allowed_roles: Sequence[int], - timeout: Optional[float] = 180.0, - message: Optional[Message] = None - ) -> None: - super().__init__(timeout=timeout) - self.allowed_users = allowed_users - self.allowed_roles = allowed_roles - self.message = message - - async def interaction_check(self, interaction: Interaction) -> bool: - """ - Ensure the user clicking the button is the view invoker, or a moderator. - - Args: - interaction: The interaction that occurred. 
- """ - if interaction.user.id in self.allowed_users: - log.trace( - "Allowed interaction by %s (%d) on %d as they are an allowed user.", - interaction.user, - interaction.user.id, - interaction.message.id, - ) - return True - - if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])): - log.trace( - "Allowed interaction by %s (%d)on %d as they have an allowed role.", - interaction.user, - interaction.user.id, - interaction.message.id, - ) - return True - - await interaction.response.send_message("This is not your button to click!", ephemeral=True) - return False - - async def on_timeout(self) -> None: - """Remove the view from ``self.message`` if set.""" - if self.message: - with contextlib.suppress(NotFound): - # Cover the case where this message has already been deleted by external means - await self.message.edit(view=None) - - -class DeleteMessageButton(ui.Button): - """ - A button that can be added to a view to delete the message containing the view on click. - - This button itself carries out no interaction checks, these should be done by the parent view. - - See :obj:`botcore.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks. - - Args: - style (:literal-url:`ButtonStyle `): - The style of the button, set to ``ButtonStyle.secondary`` if not specified. - label: The label of the button, set to "Delete" if not specified. 
- """ # noqa: E501 - - def __init__( - self, - *, - style: ButtonStyle = ButtonStyle.secondary, - label: str = "Delete", - **kwargs - ): - super().__init__(style=style, label=label, **kwargs) - - async def callback(self, interaction: Interaction) -> None: - """Delete the original message on button click.""" - await interaction.message.delete() diff --git a/botcore/utils/logging.py b/botcore/utils/logging.py deleted file mode 100644 index 1f1c8bac..00000000 --- a/botcore/utils/logging.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Common logging related functions.""" - -import logging -import typing - -if typing.TYPE_CHECKING: - LoggerClass = logging.Logger -else: - LoggerClass = logging.getLoggerClass() - -TRACE_LEVEL = 5 - - -class CustomLogger(LoggerClass): - """Custom implementation of the :obj:`logging.Logger` class with an added :obj:`trace` method.""" - - def trace(self, msg: str, *args, **kwargs) -> None: - """ - Log the given message with the severity ``"TRACE"``. - - To pass exception information, use the keyword argument exc_info with a true value: - - .. code-block:: py - - logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) - - Args: - msg: The message to be logged. - args, kwargs: Passed to the base log function as is. - """ - if self.isEnabledFor(TRACE_LEVEL): - self.log(TRACE_LEVEL, msg, *args, **kwargs) - - -def get_logger(name: typing.Optional[str] = None) -> CustomLogger: - """ - Utility to make mypy recognise that logger is of type :obj:`CustomLogger`. - - Args: - name: The name given to the logger. - - Returns: - An instance of the :obj:`CustomLogger` class. - """ - return typing.cast(CustomLogger, logging.getLogger(name)) - - -# Setup trace level logging so that we can use it within botcore. 
-logging.TRACE = TRACE_LEVEL -logging.setLoggerClass(CustomLogger) -logging.addLevelName(TRACE_LEVEL, "TRACE") diff --git a/botcore/utils/members.py b/botcore/utils/members.py deleted file mode 100644 index 1536a8d1..00000000 --- a/botcore/utils/members.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Useful helper functions for interactin with :obj:`discord.Member` objects.""" -import typing -from collections import abc - -import discord - -from botcore.utils import logging - -log = logging.get_logger(__name__) - - -async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]: - """ - Attempt to get a member from cache; on failure fetch from the API. - - Returns: - The :obj:`discord.Member` or :obj:`None` to indicate the member could not be found. - """ - if member := guild.get_member(member_id): - log.trace(f"{member} retrieved from cache.") - else: - try: - member = await guild.fetch_member(member_id) - except discord.errors.NotFound: - log.trace(f"Failed to fetch {member_id} from API.") - return None - log.trace(f"{member} fetched from API.") - return member - - -async def handle_role_change( - member: discord.Member, - coro: typing.Callable[[discord.Role], abc.Coroutine], - role: discord.Role -) -> None: - """ - Await the given ``coro`` with ``role`` as the sole argument. - - Handle errors that we expect to be raised from - :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`. - - Args: - member: The member that is being modified for logging purposes. - coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`. - role: The role to be passed to ``coro``. 
- """ - try: - await coro(role) - except discord.NotFound: - log.error(f"Failed to change role for {member} ({member.id}): member not found") - except discord.Forbidden: - log.error( - f"Forbidden to change role for {member} ({member.id}); " - f"possibly due to role hierarchy" - ) - except discord.HTTPException as e: - log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/botcore/utils/regex.py b/botcore/utils/regex.py deleted file mode 100644 index de82a1ed..00000000 --- a/botcore/utils/regex.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Common regular expressions.""" - -import re - -DISCORD_INVITE = re.compile( - r"(https?://)?(www\.)?" # Optional http(s) and www. - r"(discord([.,]|dot)gg|" # Could be discord.gg/ - r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/ - r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/ - r"discord([.,]|dot)me|" # or discord.me - r"discord([.,]|dot)li|" # or discord.li - r"discord([.,]|dot)io|" # or discord.io. - r"((?\S+)", # the invite code itself - flags=re.IGNORECASE -) -""" -Regex for Discord server invites. - -.. warning:: - This regex pattern will capture until a whitespace, if you are to use the 'invite' capture group in - any HTTP requests or similar. Please ensure you sanitise the output using something - such as :func:`urllib.parse.quote`. - -:meta hide-value: -""" - -FORMATTED_CODE_REGEX = re.compile( - r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block - r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) - r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all code inside the markup - r"\s*" # any more whitespace before the end of the code markup - r"(?P=delim)", # match the exact same delimiter from the start again - flags=re.DOTALL | re.IGNORECASE # "." 
also matches newlines, case insensitive -) -""" -Regex for formatted code, using Discord's code blocks. - -:meta hide-value: -""" - -RAW_CODE_REGEX = re.compile( - r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all the rest as code - r"\s*$", # any trailing whitespace until the end of the string - flags=re.DOTALL # "." also matches newlines -) -""" -Regex for raw code, *not* using Discord's code blocks. - -:meta hide-value: -""" diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py deleted file mode 100644 index 9517df6d..00000000 --- a/botcore/utils/scheduling.py +++ /dev/null @@ -1,252 +0,0 @@ -"""Generic python scheduler.""" - -import asyncio -import contextlib -import inspect -import typing -from collections import abc -from datetime import datetime -from functools import partial - -from botcore.utils import logging - - -class Scheduler: - """ - Schedule the execution of coroutines and keep track of them. - - When instantiating a :obj:`Scheduler`, a name must be provided. This name is used to distinguish the - instance's log messages from other instances. Using the name of the class or module containing - the instance is suggested. - - Coroutines can be scheduled immediately with :obj:`schedule` or in the future with :obj:`schedule_at` - or :obj:`schedule_later`. A unique ID is required to be given in order to keep track of the - resulting Tasks. Any scheduled task can be cancelled prematurely using :obj:`cancel` by providing - the same ID used to schedule it. - - The ``in`` operator is supported for checking if a task with a given ID is currently scheduled. - - Any exception raised in a scheduled task is logged when the task is done. - """ - - def __init__(self, name: str): - """ - Initialize a new :obj:`Scheduler` instance. - - Args: - name: The name of the :obj:`Scheduler`. Used in logging, and namespacing. 
- """ - self.name = name - - self._log = logging.get_logger(f"{__name__}.{name}") - self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {} - - def __contains__(self, task_id: abc.Hashable) -> bool: - """ - Return :obj:`True` if a task with the given ``task_id`` is currently scheduled. - - Args: - task_id: The task to look for. - - Returns: - :obj:`True` if the task was found. - """ - return task_id in self._scheduled_tasks - - def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: - """ - Schedule the execution of a ``coroutine``. - - If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - - Args: - task_id: A unique ID to create the task with. - coroutine: The function to be called. - """ - self._log.trace(f"Scheduling task #{task_id}...") - - msg = f"Cannot schedule an already started coroutine for #{task_id}" - assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg - - if task_id in self._scheduled_tasks: - self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") - coroutine.close() - return - - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") - task.add_done_callback(partial(self._task_done_callback, task_id)) - - self._scheduled_tasks[task_id] = task - self._log.debug(f"Scheduled task #{task_id} {id(task)}.") - - def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: - """ - Schedule ``coroutine`` to be executed at the given ``time``. - - If ``time`` is timezone aware, then use that timezone to calculate now() when subtracting. - If ``time`` is naïve, then use UTC. - - If ``time`` is in the past, schedule ``coroutine`` immediately. - - If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This - prevents unawaited coroutine warnings. 
Don't pass a coroutine that'll be re-used elsewhere. - - Args: - time: The time to start the task. - task_id: A unique ID to create the task with. - coroutine: The function to be called. - """ - now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() - delay = (time - now_datetime).total_seconds() - if delay > 0: - coroutine = self._await_later(delay, task_id, coroutine) - - self.schedule(task_id, coroutine) - - def schedule_later( - self, - delay: typing.Union[int, float], - task_id: abc.Hashable, - coroutine: abc.Coroutine - ) -> None: - """ - Schedule ``coroutine`` to be executed after ``delay`` seconds. - - If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - - Args: - delay: How long to wait before starting the task. - task_id: A unique ID to create the task with. - coroutine: The function to be called. - """ - self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - - def cancel(self, task_id: abc.Hashable) -> None: - """ - Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist. - - Args: - task_id: The task's unique ID. 
- """ - self._log.trace(f"Cancelling task #{task_id}...") - - try: - task = self._scheduled_tasks.pop(task_id) - except KeyError: - self._log.warning(f"Failed to unschedule {task_id} (no task found).") - else: - task.cancel() - - self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") - - def cancel_all(self) -> None: - """Unschedule all known tasks.""" - self._log.debug("Unscheduling all tasks") - - for task_id in self._scheduled_tasks.copy(): - self.cancel(task_id) - - async def _await_later( - self, - delay: typing.Union[int, float], - task_id: abc.Hashable, - coroutine: abc.Coroutine - ) -> None: - """Await ``coroutine`` after ``delay`` seconds.""" - try: - self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") - await asyncio.sleep(delay) - - # Use asyncio.shield to prevent the coroutine from cancelling itself. - self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") - await asyncio.shield(coroutine) - finally: - # Close it to prevent unawaited coroutine warnings, - # which would happen if the task was cancelled during the sleep. - # Only close it if it's not been awaited yet. This check is important because the - # coroutine may cancel this task, which would also trigger the finally block. - state = inspect.getcoroutinestate(coroutine) - if state == "CORO_CREATED": - self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") - coroutine.close() - else: - self._log.debug(f"Finally block reached for #{task_id}; {state=}") - - def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None: - """ - Delete the task and raise its exception if one exists. - - If ``done_task`` and the task associated with ``task_id`` are different, then the latter - will not be deleted. In this case, a new task was likely rescheduled with the same ID. 
- """ - self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") - - scheduled_task = self._scheduled_tasks.get(task_id) - - if scheduled_task and done_task is scheduled_task: - # A task for the ID exists and is the same as the done task. - # Since this is the done callback, the task is already done so no need to cancel it. - self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") - del self._scheduled_tasks[task_id] - elif scheduled_task: - # A new task was likely rescheduled with the same ID. - self._log.debug( - f"The scheduled task #{task_id} {id(scheduled_task)} " - f"and the done task {id(done_task)} differ." - ) - elif not done_task.cancelled(): - self._log.warning( - f"Task #{task_id} not found while handling task {id(done_task)}! " - f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." - ) - - with contextlib.suppress(asyncio.CancelledError): - exception = done_task.exception() - # Log the exception if one exists. - if exception: - self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) - - -TASK_RETURN = typing.TypeVar("TASK_RETURN") - - -def create_task( - coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN], - *, - suppressed_exceptions: tuple[type[Exception], ...] = (), - event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, - **kwargs, -) -> asyncio.Task[TASK_RETURN]: - """ - Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task. - - If the ``event_loop`` kwarg is provided, the task is created from that event loop, - otherwise the running loop is used. - - Args: - coro: The function to call. - suppressed_exceptions: Exceptions to be handled by the task. - event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from. - kwargs: Passed to :py:func:`asyncio.create_task`. - - Returns: - asyncio.Task: The wrapped task. 
- """ - if event_loop is not None: - task = event_loop.create_task(coro, **kwargs) - else: - task = asyncio.create_task(coro, **kwargs) - task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) - return task - - -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None: - """Retrieve and log the exception raised in ``task`` if one exists.""" - with contextlib.suppress(asyncio.CancelledError): - exception = task.exception() - # Log the exception if one exists. - if exception and not isinstance(exception, suppressed_exceptions): - log = logging.get_logger(__name__) - log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/dev/README.rst b/dev/README.rst index ae4f3adc..9428d788 100644 --- a/dev/README.rst +++ b/dev/README.rst @@ -3,7 +3,7 @@ Local Development & Testing To test your features locally, there are a few possible approaches: -1. Install your local copy of botcore into a pre-existing project such as bot +1. Install your local copy of pydis_core into a pre-existing project such as bot 2. Use the provided template from the :repo-file:`dev/bot ` folder See below for more info on both approaches. @@ -17,12 +17,12 @@ vary by the feature you're working on. Option 1 -------- 1. Navigate to the project you want to install bot-core in, such as bot or sir-lancebot -2. Run ``pip install /path/to/botcore`` in the project's environment +2. Run ``pip install /path/to/pydis_core`` in the project's environment - The path provided to install should be the root directory of this project on your machine. That is, the folder which contains the ``pyproject.toml`` file. - Make sure to install in the correct environment. Most Python Discord projects use - poetry, so you can run ``poetry run pip install /path/to/botcore``. + poetry, so you can run ``poetry run pip install /path/to/pydis_core``. 3. You can now use features from your local bot-core changes. 
To load new changes, run the install command again. diff --git a/dev/bot/__init__.py b/dev/bot/__init__.py index 71871209..6ee1ae47 100644 --- a/dev/bot/__init__.py +++ b/dev/bot/__init__.py @@ -3,7 +3,7 @@ import logging import os import sys -import botcore +import pydis_core if os.name == "nt": # Change the event loop policy on Windows to avoid exceptions on exit @@ -15,7 +15,7 @@ logging.getLogger().setLevel(logging.DEBUG) logging.getLogger("discord").setLevel(logging.ERROR) -class Bot(botcore.BotBase): +class Bot(pydis_core.BotBase): """Sample Bot implementation.""" async def setup_hook(self) -> None: diff --git a/dev/bot/__main__.py b/dev/bot/__main__.py index 42d212c2..1b1a551a 100644 --- a/dev/bot/__main__.py +++ b/dev/bot/__main__.py @@ -6,11 +6,11 @@ import discord import dotenv from discord.ext import commands -import botcore +import pydis_core from . import Bot dotenv.load_dotenv() -botcore.utils.apply_monkey_patches() +pydis_core.utils.apply_monkey_patches() roles = os.getenv("ALLOWED_ROLES") roles = [int(role) for role in roles.split(",")] if roles else [] diff --git a/docker-compose.yaml b/docker-compose.yaml index af882428..078ee6bb 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -69,7 +69,7 @@ services: context: . dockerfile: dev/Dockerfile volumes: # Don't do .:/app here to ensure project venv from host doens't overwrite venv in image - - ./botcore:/app/botcore:ro + - ./pydis_core:/app/pydis_core:ro - ./bot:/app/bot:ro tty: true depends_on: diff --git a/docs/changelog.rst b/docs/changelog.rst index 3e3c7149..819e7d38 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,10 @@ Changelog ========= +- :release:`9.0.0 <5th November 2022>` +- :breaking:`157` Rename project to pydis_core to allow for publishing to pypi. + + - :release:`8.2.1 <18th September 2022>` - :bug:`138` Bump Discord.py to :literal-url:`2.0.1 `. 
@@ -13,7 +17,7 @@ Changelog - :release:`8.1.0 <16th August 2022>` -- :support:`124` Updated :obj:`botcore.utils.regex.DISCORD_INVITE` regex to optionally match leading "http[s]" and "www". +- :support:`124` Updated :obj:`pydis_core.utils.regex.DISCORD_INVITE` regex to optionally match leading "http[s]" and "www". - :release:`8.0.0 <27th July 2022>` @@ -28,16 +32,16 @@ Changelog - :release:`7.4.0 <17th July 2022>` -- :feature:`106` Add an optional ``message`` attr to :obj:`botcore.utils.interactions.ViewWithUserAndRoleCheck`. On view timeout, this message has its view removed if set. +- :feature:`106` Add an optional ``message`` attr to :obj:`pydis_core.utils.interactions.ViewWithUserAndRoleCheck`. On view timeout, this message has its view removed if set. - :release:`7.3.1 <16th July 2022>` -- :bug:`104` Fix :obj:`botcore.utils.interactions.DeleteMessageButton` not working due to using wrong delete method. +- :bug:`104` Fix :obj:`pydis_core.utils.interactions.DeleteMessageButton` not working due to using wrong delete method. - :release:`7.3.0 <16th July 2022>` -- :feature:`103` Add a generic view :obj:`botcore.utils.interactions.ViewWithUserAndRoleCheck` that only allows specified users and roles to interaction with it -- :feature:`103` Add a button :obj:`botcore.utils.interactions.DeleteMessageButton` that deletes the message attached to its parent view. +- :feature:`103` Add a generic view :obj:`pydis_core.utils.interactions.ViewWithUserAndRoleCheck` that only allows specified users and roles to interaction with it +- :feature:`103` Add a button :obj:`pydis_core.utils.interactions.DeleteMessageButton` that deletes the message attached to its parent view. - :release:`7.2.2 <9th July 2022>` @@ -46,7 +50,7 @@ Changelog - :release:`7.2.1 <30th June 2022>` - :bug:`96` Fix attempts to connect to ``BotBase.statsd_url`` when it is None. -- :bug:`91` Fix incorrect docstring for ``botcore.utils.member.handle_role_change``. 
+- :bug:`91` Fix incorrect docstring for ``pydis_core.utils.member.handle_role_change``. - :bug:`91` Pass missing self parameter to ``BotBase.ping_services``. - :bug:`91` Add missing await to ``BotBase.ping_services`` in some cases. @@ -96,7 +100,7 @@ Changelog - :release:`6.1.0 <20th April 2022>` -- :feature:`65` Add ``unqualify`` to the ``botcore.utils`` namespace for use in bots that manipulate extensions. +- :feature:`65` Add ``unqualify`` to the ``pydis_core.utils`` namespace for use in bots that manipulate extensions. - :release:`6.0.0 <19th April 2022>` @@ -112,7 +116,7 @@ Changelog Feature 63 Needs to be explicitly included above because it was improperly released within a bugfix version instead of a minor release -- :feature:`63` Allow passing an ``api_client`` to ``BotBase.__init__`` to specify the ``botcore.site_api.APIClient`` instance to use. +- :feature:`63` Allow passing an ``api_client`` to ``BotBase.__init__`` to specify the ``pydis_core.site_api.APIClient`` instance to use. - :release:`5.0.3 <18th April 2022>` @@ -140,11 +144,11 @@ Changelog - :release:`3.0.1 <5th March 2022>` -- :bug:`37` Setup log tracing when ``botcore.utils.logging`` is imported so that it can be used within botcore functions. +- :bug:`37` Setup log tracing when ``pydis_core.utils.logging`` is imported so that it can be used within pydis_core functions. - :release:`3.0.0 <3rd March 2022>` -- :breaking:`35` Move ``apply_monkey_patches()`` directly to `botcore.utils` namespace. +- :breaking:`35` Move ``apply_monkey_patches()`` directly to `pydis_core.utils` namespace. - :release:`2.1.0 <24th February 2022>` @@ -152,7 +156,7 @@ Changelog - :release:`2.0.0 <22nd February 2022>` -- :breaking:`35` Moved regex to ``botcore.utils`` namespace +- :breaking:`35` Moved regex to ``pydis_core.utils`` namespace - :breaking:`32` Migrate from discord.py 2.0a0 to disnake. - :feature:`32` Add common monkey patches. 
- :feature:`29` Port many common utilities from our bots: diff --git a/docs/index.rst b/docs/index.rst index aee7b269..259d01cc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,7 +11,7 @@ Reference :maxdepth: 4 :caption: Modules: - output/botcore + output/pydis_core .. toctree:: :caption: Other: diff --git a/docs/utils.py b/docs/utils.py index c8bbc895..e7295798 100644 --- a/docs/utils.py +++ b/docs/utils.py @@ -134,12 +134,12 @@ def cleanup() -> None: included = __get_included() for file in (get_build_root() / "docs" / "output").iterdir(): - if file.name in ("botcore.rst", "botcore.exts.rst", "botcore.utils.rst") and file.name in included: + if file.name in ("pydis_core.rst", "pydis_core.exts.rst", "pydis_core.utils.rst") and file.name in included: content = file.read_text(encoding="utf-8").splitlines(keepends=True) # Rename the extension to be less wordy - # Example: botcore.exts -> Botcore Exts - title = content[0].split()[0].strip().replace("botcore.", "").replace(".", " ").title() + # Example: pydis_core.exts -> pydis_core Exts + title = content[0].split()[0].strip().replace("pydis_core.", "").replace(".", " ").title() title = f"{title}\n{'=' * len(title)}\n\n" content = title, *content[3:] @@ -147,7 +147,7 @@ def cleanup() -> None: elif file.name in included: # Clean up the submodule name so it's just the name without the top level module name - # example: `botcore.regex module` -> `regex` + # example: `pydis_core.regex module` -> `regex` lines = file.read_text(encoding="utf-8").splitlines(keepends=True) lines[0] = lines[0].replace("module", "").strip().split(".")[-1] + "\n" file.write_text("".join(lines)) @@ -164,7 +164,7 @@ def cleanup() -> None: def build_api_doc() -> None: """Generate auto-module directives using apidoc.""" - cmd = os.getenv("APIDOC_COMMAND") or "sphinx-apidoc -o docs/output botcore -feM" + cmd = os.getenv("APIDOC_COMMAND") or "sphinx-apidoc -o docs/output pydis_core -feM" cmd = cmd.split() build_root = get_build_root() @@ -196,7 
+196,7 @@ def __get_included() -> set[str]: return _modules - return get_all_from_module("botcore") + return get_all_from_module("pydis_core") def reorder_release_entries(release_list: list[releases.Release]) -> None: diff --git a/pydis_core/__init__.py b/pydis_core/__init__.py new file mode 100644 index 00000000..a09feeaa --- /dev/null +++ b/pydis_core/__init__.py @@ -0,0 +1,15 @@ +"""Useful utilities and tools for Discord bot development.""" + +from pydis_core import async_stats, exts, site_api, utils +from pydis_core._bot import BotBase, StartupError + +__all__ = [ + async_stats, + BotBase, + exts, + utils, + site_api, + StartupError, +] + +__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/pydis_core/_bot.py b/pydis_core/_bot.py new file mode 100644 index 00000000..56814f27 --- /dev/null +++ b/pydis_core/_bot.py @@ -0,0 +1,288 @@ +import asyncio +import socket +import types +import warnings +from contextlib import suppress +from typing import Optional + +import aiohttp +import discord +from discord.ext import commands + +from pydis_core.async_stats import AsyncStatsClient +from pydis_core.site_api import APIClient +from pydis_core.utils import scheduling +from pydis_core.utils._extensions import walk_extensions +from pydis_core.utils.logging import get_logger + +try: + from async_rediscache import RedisSession + REDIS_AVAILABLE = True +except ImportError: + RedisSession = None + REDIS_AVAILABLE = False + +log = get_logger() + + +class StartupError(Exception): + """Exception class for startup errors.""" + + def __init__(self, base: Exception): + super().__init__() + self.exception = base + + +class BotBase(commands.Bot): + """A sub-class that implements many common features that Python Discord bots use.""" + + def __init__( + self, + *args, + guild_id: int, + allowed_roles: list, + http_session: aiohttp.ClientSession, + redis_session: Optional[RedisSession] = None, + api_client: Optional[APIClient] = None, + statsd_url: Optional[str] = 
None, + **kwargs, + ): + """ + Initialise the base bot instance. + + Args: + guild_id: The ID of the guild used for :func:`wait_until_guild_available`. + allowed_roles: A list of role IDs that the bot is allowed to mention. + http_session (aiohttp.ClientSession): The session to use for the bot. + redis_session: The `async_rediscache.RedisSession`_ to use for the bot. + api_client: The :obj:`pydis_core.site_api.APIClient` instance to use for the bot. + statsd_url: The URL of the statsd server to use for the bot. If not given, + a dummy statsd client will be created. + + .. _async_rediscache.RedisSession: https://github.com/SebastiaanZ/async-rediscache#creating-a-redissession + """ + super().__init__( + *args, + allowed_roles=allowed_roles, + **kwargs, + ) + + self.guild_id = guild_id + self.http_session = http_session + self.api_client = api_client + self.statsd_url = statsd_url + + if redis_session and not REDIS_AVAILABLE: + warnings.warn("redis_session kwarg passed, but async-rediscache not installed!") + elif redis_session: + self.redis_session = redis_session + + self._resolver: Optional[aiohttp.AsyncResolver] = None + self._connector: Optional[aiohttp.TCPConnector] = None + + self._statsd_timerhandle: Optional[asyncio.TimerHandle] = None + self._guild_available: Optional[asyncio.Event] = None + + self.stats: Optional[AsyncStatsClient] = None + + self.all_extensions: Optional[frozenset[str]] = None + + def _connect_statsd( + self, + statsd_url: str, + loop: asyncio.AbstractEventLoop, + retry_after: int = 2, + attempt: int = 1 + ) -> None: + """Callback used to retry a connection to statsd if it should fail.""" + if attempt >= 8: + log.error( + "Reached 8 attempts trying to reconnect AsyncStatsClient to %s. 
" + "Aborting and leaving the dummy statsd client in place.", + statsd_url, + ) + return + + try: + self.stats = AsyncStatsClient(loop, statsd_url, 8125, prefix="bot") + except socket.gaierror: + log.warning(f"Statsd client failed to connect (Attempt(s): {attempt})") + # Use a fallback strategy for retrying, up to 8 times. + self._statsd_timerhandle = loop.call_later( + retry_after, + self._connect_statsd, + statsd_url, + retry_after * 2, + attempt + 1 + ) + + async def load_extensions(self, module: types.ModuleType) -> None: + """ + Load all the extensions within the given module and save them to ``self.all_extensions``. + + This should be ran in a task on the event loop to avoid deadlocks caused by ``wait_for`` calls. + """ + await self.wait_until_guild_available() + self.all_extensions = walk_extensions(module) + + for extension in self.all_extensions: + scheduling.create_task(self.load_extension(extension)) + + def _add_root_aliases(self, command: commands.Command) -> None: + """Recursively add root aliases for ``command`` and any of its subcommands.""" + if isinstance(command, commands.Group): + for subcommand in command.commands: + self._add_root_aliases(subcommand) + + for alias in getattr(command, "root_aliases", ()): + if alias in self.all_commands: + raise commands.CommandRegistrationError(alias, alias_conflict=True) + + self.all_commands[alias] = command + + def _remove_root_aliases(self, command: commands.Command) -> None: + """Recursively remove root aliases for ``command`` and any of its subcommands.""" + if isinstance(command, commands.Group): + for subcommand in command.commands: + self._remove_root_aliases(subcommand) + + for alias in getattr(command, "root_aliases", ()): + self.all_commands.pop(alias, None) + + async def add_cog(self, cog: commands.Cog) -> None: + """Add the given ``cog`` to the bot and log the operation.""" + await super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + def add_command(self, command: 
commands.Command) -> None: + """Add ``command`` as normal and then add its root aliases to the bot.""" + super().add_command(command) + self._add_root_aliases(command) + + def remove_command(self, name: str) -> Optional[commands.Command]: + """ + Remove a command/alias as normal and then remove its root aliases from the bot. + + Individual root aliases cannot be removed by this function. + To remove them, either remove the entire command or manually edit `bot.all_commands`. + """ + command = super().remove_command(name) + if command is None: + # Even if it's a root alias, there's no way to get the Bot instance to remove the alias. + return None + + self._remove_root_aliases(command) + return command + + def clear(self) -> None: + """Not implemented! Re-instantiate the bot instead of attempting to re-use a closed one.""" + raise NotImplementedError("Re-using a Bot object after closing it is not supported.") + + async def on_guild_unavailable(self, guild: discord.Guild) -> None: + """Clear the internal guild available event when self.guild_id becomes unavailable.""" + if guild.id != self.guild_id: + return + + self._guild_available.clear() + + async def on_guild_available(self, guild: discord.Guild) -> None: + """ + Set the internal guild available event when self.guild_id becomes available. + + If the cache appears to still be empty (no members, no channels, or no roles), the event + will not be set and `guild_available_but_cache_empty` event will be emitted. + """ + if guild.id != self.guild_id: + return + + if not guild.roles or not guild.members or not guild.channels: + msg = "Guild available event was dispatched but the cache appears to still be empty!" + await self.log_to_dev_log(msg) + return + + self._guild_available.set() + + async def log_to_dev_log(self, message: str) -> None: + """Log the given message to #dev-log.""" + ... 
+ + async def wait_until_guild_available(self) -> None: + """ + Wait until the guild that matches the ``guild_id`` given at init is available (and the cache is ready). + + The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE + gateway event before giving up and thus not populating the cache for unavailable guilds. + """ + await self._guild_available.wait() + + async def setup_hook(self) -> None: + """ + An async init to start up generic services. + + Connects to statsd, and calls + :func:`AsyncStatsClient.create_socket ` + and :func:`ping_services`. + """ + loop = asyncio.get_running_loop() + + self._guild_available = asyncio.Event() + + self._resolver = aiohttp.AsyncResolver() + self._connector = aiohttp.TCPConnector( + resolver=self._resolver, + family=socket.AF_INET, + ) + self.http.connector = self._connector + + if getattr(self, "redis_session", False) and not self.redis_session.valid: + # If the RedisSession was somehow closed, we try to reconnect it + # here. Normally, this shouldn't happen. + await self.redis_session.connect(ping=True) + + # Create dummy stats client first, in case `statsd_url` is unreachable or None + self.stats = AsyncStatsClient(loop, "127.0.0.1") + if self.statsd_url: + self._connect_statsd(self.statsd_url, loop) + + await self.stats.create_socket() + + try: + await self.ping_services() + except Exception as e: + raise StartupError(e) + + async def ping_services(self) -> None: + """Ping all required services on setup to ensure they are up before starting.""" + ... + + async def close(self) -> None: + """Close the Discord connection, and the aiohttp session, connector, statsd client, and resolver.""" + # Done before super().close() to allow tasks to finish before the HTTP session closes. 
+ for ext in list(self.extensions): + with suppress(Exception): + await self.unload_extension(ext) + + for cog in list(self.cogs): + with suppress(Exception): + await self.remove_cog(cog) + + # Now actually do full close of bot + await super().close() + + if self.api_client: + await self.api_client.close() + + if self.http_session: + await self.http_session.close() + + if self._connector: + await self._connector.close() + + if self._resolver: + await self._resolver.close() + + if getattr(self.stats, "_transport", False): + self.stats._transport.close() + + if self._statsd_timerhandle: + self._statsd_timerhandle.cancel() diff --git a/pydis_core/async_stats.py b/pydis_core/async_stats.py new file mode 100644 index 00000000..411325e3 --- /dev/null +++ b/pydis_core/async_stats.py @@ -0,0 +1,57 @@ +"""An async transport method for statsd communication.""" + +import asyncio +import socket +from typing import Optional + +from statsd.client.base import StatsClientBase + +from pydis_core.utils import scheduling + + +class AsyncStatsClient(StatsClientBase): + """An async implementation of :obj:`statsd.client.base.StatsClientBase` that supports async stat communication.""" + + def __init__( + self, + loop: asyncio.AbstractEventLoop, + host: str = 'localhost', + port: int = 8125, + prefix: str = None + ): + """ + Create a new :obj:`AsyncStatsClient`. + + Args: + loop (asyncio.AbstractEventLoop): The event loop to use when creating the + :obj:`asyncio.loop.create_datagram_endpoint`. + host: The host to connect to. + port: The port to connect to. + prefix: The prefix to use for all stats. 
+ """ + _, _, _, _, addr = socket.getaddrinfo( + host, port, socket.AF_INET, socket.SOCK_DGRAM + )[0] + self._addr = addr + self._prefix = prefix + self._loop = loop + self._transport: Optional[asyncio.DatagramTransport] = None + + async def create_socket(self) -> None: + """Use :obj:`asyncio.loop.create_datagram_endpoint` from the loop given on init to create a socket.""" + self._transport, _ = await self._loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + family=socket.AF_INET, + remote_addr=self._addr + ) + + def _send(self, data: str) -> None: + """Start an async task to send data to statsd.""" + scheduling.create_task(self._async_send(data), event_loop=self._loop) + + async def _async_send(self, data: str) -> None: + """Send data to the statsd server using the async transport.""" + self._transport.sendto(data.encode('ascii'), self._addr) + + +__all__ = ['AsyncStatsClient'] diff --git a/pydis_core/exts/__init__.py b/pydis_core/exts/__init__.py new file mode 100644 index 00000000..afd56166 --- /dev/null +++ b/pydis_core/exts/__init__.py @@ -0,0 +1,4 @@ +"""Reusable Discord cogs.""" +__all__ = [] + +__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/pydis_core/site_api.py b/pydis_core/site_api.py new file mode 100644 index 00000000..c17d2642 --- /dev/null +++ b/pydis_core/site_api.py @@ -0,0 +1,157 @@ +"""An API wrapper around the Site API.""" + +import asyncio +from typing import Optional +from urllib.parse import quote as quote_url + +import aiohttp + +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) + + +class ResponseCodeError(ValueError): + """Raised in :meth:`APIClient.request` when a non-OK HTTP response is received.""" + + def __init__( + self, + response: aiohttp.ClientResponse, + response_json: Optional[dict] = None, + response_text: Optional[str] = None + ): + """ + Initialize a new :obj:`ResponseCodeError` instance. 
+ + Args: + response (:obj:`aiohttp.ClientResponse`): The response object from the request. + response_json: The JSON response returned from the request, if any. + response_text: The text of the request, if any. + """ + self.status = response.status + self.response_json = response_json or {} + self.response_text = response_text + self.response = response + + def __str__(self): + """Return a string representation of the error.""" + response = self.response_json or self.response_text + return f"Status: {self.status} Response: {response}" + + +class APIClient: + """A wrapper for the Django Site API.""" + + session: Optional[aiohttp.ClientSession] = None + loop: asyncio.AbstractEventLoop = None + + def __init__(self, site_api_url: str, site_api_token: str, **session_kwargs): + """ + Initialize a new :obj:`APIClient` instance. + + Args: + site_api_url: The URL of the site API. + site_api_token: The token to use for authentication. + session_kwargs: Keyword arguments to pass to the :obj:`aiohttp.ClientSession` constructor. + """ + self.site_api_url = site_api_url + + auth_headers = { + 'Authorization': f"Token {site_api_token}" + } + + if 'headers' in session_kwargs: + session_kwargs['headers'].update(auth_headers) + else: + session_kwargs['headers'] = auth_headers + + # aiohttp will complain if APIClient gets instantiated outside a coroutine. Thankfully, we + # don't and shouldn't need to do that, so we can avoid scheduling a task to create it. + self.session = aiohttp.ClientSession(**session_kwargs) + + def _url_for(self, endpoint: str) -> str: + return f"{self.site_api_url}/{quote_url(endpoint)}" + + async def close(self) -> None: + """Close the aiohttp session.""" + await self.session.close() + + @staticmethod + async def maybe_raise_for_status(response: aiohttp.ClientResponse, should_raise: bool) -> None: + """ + Raise :exc:`ResponseCodeError` for non-OK response if an exception should be raised. 
+ + Args: + response (:obj:`aiohttp.ClientResponse`): The response to check. + should_raise: Whether or not to raise an exception. + + Raises: + :exc:`ResponseCodeError`: + If the response is not OK and ``should_raise`` is True. + """ + if should_raise and response.status >= 400: + try: + response_json = await response.json() + raise ResponseCodeError(response=response, response_json=response_json) + except aiohttp.ContentTypeError: + response_text = await response.text() + raise ResponseCodeError(response=response, response_text=response_text) + + async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """ + Send an HTTP request to the site API and return the JSON response. + + Args: + method: The HTTP method to use. + endpoint: The endpoint to send the request to. + raise_for_status: Whether or not to raise an exception if the response is not OK. + **kwargs: Any extra keyword arguments to pass to :func:`aiohttp.request`. + + Returns: + The JSON response the API returns. + + Raises: + :exc:`ResponseCodeError`: + If the response is not OK and ``raise_for_status`` is True. 
+ """ + async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp: + await self.maybe_raise_for_status(resp, raise_for_status) + return await resp.json() + + async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Equivalent to :meth:`APIClient.request` with GET passed as the method.""" + return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) + + async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Equivalent to :meth:`APIClient.request` with PATCH passed as the method.""" + return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) + + async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Equivalent to :meth:`APIClient.request` with POST passed as the method.""" + return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) + + async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Equivalent to :meth:`APIClient.request` with PUT passed as the method.""" + return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) + + async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]: + """ + Send a DELETE request to the site API and return the JSON response. + + Args: + endpoint: The endpoint to send the request to. + raise_for_status: Whether or not to raise an exception if the response is not OK. + **kwargs: Any extra keyword arguments to pass to :func:`aiohttp.request`. + + Returns: + The JSON response the API returns, or None if the response is 204 No Content. 
+ """ + async with self.session.delete(self._url_for(endpoint), **kwargs) as resp: + if resp.status == 204: + return None + + await self.maybe_raise_for_status(resp, raise_for_status) + return await resp.json() + + +__all__ = ['APIClient', 'ResponseCodeError'] diff --git a/pydis_core/utils/__init__.py b/pydis_core/utils/__init__.py new file mode 100644 index 00000000..0542231e --- /dev/null +++ b/pydis_core/utils/__init__.py @@ -0,0 +1,50 @@ +"""Useful utilities and tools for Discord bot development.""" + +from pydis_core.utils import ( + _monkey_patches, + caching, + channel, + commands, + cooldown, + function, + interactions, + logging, + members, + regex, + scheduling, +) +from pydis_core.utils._extensions import unqualify + + +def apply_monkey_patches() -> None: + """ + Applies all common monkey patches for our bots. + + Patches :obj:`discord.ext.commands.Command` and :obj:`discord.ext.commands.Group` to support root aliases. + A ``root_aliases`` keyword argument is added to these two objects, which is a sequence of alias names + that will act as top-level groups rather than being aliases of the command's group. + + It's stored as an attribute also named ``root_aliases`` + + Patches discord's internal ``send_typing`` method so that it ignores 403 errors from Discord. + When under heavy load Discord has added a CloudFlare worker to this route, which causes 403 errors to be thrown. 
+ """ + _monkey_patches._apply_monkey_patches() + + +__all__ = [ + apply_monkey_patches, + caching, + channel, + commands, + cooldown, + function, + interactions, + logging, + members, + regex, + scheduling, + unqualify, +] + +__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/pydis_core/utils/_extensions.py b/pydis_core/utils/_extensions.py new file mode 100644 index 00000000..536a0715 --- /dev/null +++ b/pydis_core/utils/_extensions.py @@ -0,0 +1,57 @@ +"""Utilities for loading Discord extensions.""" + +import importlib +import inspect +import pkgutil +import types +from typing import NoReturn + + +def unqualify(name: str) -> str: + """ + Return an unqualified name given a qualified module/package ``name``. + + Args: + name: The module name to unqualify. + + Returns: + The unqualified module name. + """ + return name.rsplit(".", maxsplit=1)[-1] + + +def ignore_module(module: pkgutil.ModuleInfo) -> bool: + """Return whether the module with name `name` should be ignored.""" + return any(name.startswith("_") for name in module.name.split(".")) + + +def walk_extensions(module: types.ModuleType) -> frozenset[str]: + """ + Return all extension names from the given module. + + Args: + module (types.ModuleType): The module to look for extensions in. + + Returns: + A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. + """ + + def on_error(name: str) -> NoReturn: + raise ImportError(name=name) # pragma: no cover + + modules = set() + + for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): + if ignore_module(module_info): + # Ignore modules/packages that have a name starting with an underscore anywhere in their trees. + continue + + if module_info.ispkg: + imported = importlib.import_module(module_info.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. 
+ continue + + modules.add(module_info.name) + + return frozenset(modules) diff --git a/pydis_core/utils/_monkey_patches.py b/pydis_core/utils/_monkey_patches.py new file mode 100644 index 00000000..f0a8dc9c --- /dev/null +++ b/pydis_core/utils/_monkey_patches.py @@ -0,0 +1,73 @@ +"""Contains all common monkey patches, used to alter discord to fit our needs.""" + +import logging +import typing +from datetime import datetime, timedelta +from functools import partial, partialmethod + +from discord import Forbidden, http +from discord.ext import commands + +log = logging.getLogger(__name__) + + +class _Command(commands.Command): + """ + A :obj:`discord.ext.commands.Command` subclass which supports root aliases. + + A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as + top-level commands rather than being aliases of the command's group. It's stored as an attribute + also named ``root_aliases``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.root_aliases = kwargs.get("root_aliases", []) + + if not isinstance(self.root_aliases, (list, tuple)): + raise TypeError("Root aliases of a command must be a list or a tuple of strings.") + + +class _Group(commands.Group, _Command): + """ + A :obj:`discord.ext.commands.Group` subclass which supports root aliases. + + A ``root_aliases`` keyword argument is added, which is a sequence of alias names that will act as + top-level groups rather than being aliases of the command's group. It's stored as an attribute + also named ``root_aliases``. + """ + + +def _patch_typing() -> None: + """ + Sometimes Discord turns off typing events by throwing 403s. + + Handle those issues by patching discord's internal ``send_typing`` method so it ignores 403s in general. + """ + log.debug("Patching send_typing, which should fix things breaking when Discord disables typing events. 
Stay safe!") + + original = http.HTTPClient.send_typing + last_403: typing.Optional[datetime] = None + + async def honeybadger_type(self: http.HTTPClient, channel_id: int) -> None: + nonlocal last_403 + if last_403 and (datetime.utcnow() - last_403) < timedelta(minutes=5): + log.warning("Not sending typing event, we got a 403 less than 5 minutes ago.") + return + try: + await original(self, channel_id) + except Forbidden: + last_403 = datetime.utcnow() + log.warning("Got a 403 from typing event!") + + http.HTTPClient.send_typing = honeybadger_type + + +def _apply_monkey_patches() -> None: + """This is surfaced directly in pydis_core.utils.apply_monkey_patches().""" + commands.command = partial(commands.command, cls=_Command) + commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=_Command) + + commands.group = partial(commands.group, cls=_Group) + commands.GroupMixin.group = partialmethod(commands.GroupMixin.group, cls=_Group) + _patch_typing() diff --git a/pydis_core/utils/caching.py b/pydis_core/utils/caching.py new file mode 100644 index 00000000..ac34bb9b --- /dev/null +++ b/pydis_core/utils/caching.py @@ -0,0 +1,65 @@ +"""Utilities related to custom caches.""" + +import functools +import typing +from collections import OrderedDict + + +class AsyncCache: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + + def __init__(self, max_size: int = 128): + """ + Initialise a new :obj:`AsyncCache` instance. + + Args: + max_size: How many items to store in the cache. + """ + self._cache = OrderedDict() + self._max_size = max_size + + def __call__(self, arg_offset: int = 0) -> typing.Callable: + """ + Decorator for async cache. + + Args: + arg_offset: The offset for the position of the key argument. + + Returns: + A decorator to wrap the target function. 
+ """ + + def decorator(function: typing.Callable) -> typing.Callable: + """ + Define the async cache decorator. + + Args: + function: The function to wrap. + + Returns: + The wrapped function. + """ + + @functools.wraps(function) + async def wrapper(*args) -> typing.Any: + """Decorator wrapper for the caching logic.""" + key = args[arg_offset:] + + if key not in self._cache: + if len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + self._cache[key] = await function(*args) + return self._cache[key] + return wrapper + return decorator + + def clear(self) -> None: + """Clear cache instance.""" + self._cache.clear() diff --git a/pydis_core/utils/channel.py b/pydis_core/utils/channel.py new file mode 100644 index 00000000..854c64fd --- /dev/null +++ b/pydis_core/utils/channel.py @@ -0,0 +1,54 @@ +"""Useful helper functions for interacting with various discord channel objects.""" + +import discord +from discord.ext.commands import Bot + +from pydis_core.utils import logging + +log = logging.get_logger(__name__) + + +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: + """ + Return whether the given ``channel`` in the the category with the id ``category_id``. + + Args: + channel: The channel to check. + category_id: The category to check for. + + Returns: + A bool depending on whether the channel is in the category. + """ + return getattr(channel, "category_id", None) == category_id + + +async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel: + """ + Attempt to get or fetch the given ``channel_id`` from the bots cache, and return it. + + Args: + bot: The :obj:`discord.ext.commands.Bot` instance to use for getting/fetching. + channel_id: The channel to get/fetch. + + Raises: + :exc:`discord.InvalidData` + An unknown channel type was received from Discord. + :exc:`discord.HTTPException` + Retrieving the channel failed. + :exc:`discord.NotFound` + Invalid Channel ID. 
+ :exc:`discord.Forbidden` + You do not have permission to fetch this channel. + + Returns: + The channel from the ID. + """ + log.trace(f"Getting the channel {channel_id}.") + + channel = bot.get_channel(channel_id) + if not channel: + log.debug(f"Channel {channel_id} is not in cache; fetching from API.") + channel = await bot.fetch_channel(channel_id) + + log.trace(f"Channel #{channel} ({channel_id}) retrieved.") + return channel diff --git a/pydis_core/utils/commands.py b/pydis_core/utils/commands.py new file mode 100644 index 00000000..7afd8137 --- /dev/null +++ b/pydis_core/utils/commands.py @@ -0,0 +1,38 @@ +from typing import Optional + +from discord import Message +from discord.ext.commands import BadArgument, Context, clean_content + + +async def clean_text_or_reply(ctx: Context, text: Optional[str] = None) -> str: + """ + Cleans a text argument or replied message's content. + + Args: + ctx: The command's context + text: The provided text argument of the command (if given) + + Raises: + :exc:`discord.ext.commands.BadArgument` + `text` wasn't provided and there's no reply message / reply message content. + + Returns: + The cleaned version of `text`, if given, else replied message. + """ + clean_content_converter = clean_content(fix_channel_mentions=True) + + if text: + return await clean_content_converter.convert(ctx, text) + + if ( + (replied_message := getattr(ctx.message.reference, "resolved", None)) # message has a cached reference + and isinstance(replied_message, Message) # referenced message hasn't been deleted + ): + if not (content := ctx.message.reference.resolved.content): + # The referenced message doesn't have a content (e.g. embed/image), so raise error + raise BadArgument("The referenced message doesn't have a text content.") + + return await clean_content_converter.convert(ctx, content) + + # No text provided, and either no message was referenced or we can't access the content + raise BadArgument("Couldn't find text to clean. 
Provide a string or reply to a message to use its content.") diff --git a/pydis_core/utils/cooldown.py b/pydis_core/utils/cooldown.py new file mode 100644 index 00000000..5129befd --- /dev/null +++ b/pydis_core/utils/cooldown.py @@ -0,0 +1,220 @@ +"""Helpers for setting a cooldown on commands.""" + +from __future__ import annotations + +import asyncio +import random +import time +import typing +import weakref +from collections.abc import Awaitable, Callable, Hashable, Iterable +from contextlib import suppress +from dataclasses import dataclass + +import discord +from discord.ext.commands import CommandError, Context + +from pydis_core.utils import scheduling +from pydis_core.utils.function import command_wraps + +__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"] + +_KEYWORD_SEP_SENTINEL = object() + +_ArgsList = list[object] +_HashableArgsTuple = tuple[Hashable, ...] + +if typing.TYPE_CHECKING: + import typing_extensions + from pydis_core import BotBase + +P = typing.ParamSpec("P") +"""The command's signature.""" +R = typing.TypeVar("R") +"""The command's return value.""" + + +class CommandOnCooldown(CommandError, typing.Generic[P, R]): + """Raised when a command is invoked while on cooldown.""" + + def __init__( + self, + message: str | None, + function: Callable[P, Awaitable[R]], + /, + *args: P.args, + **kwargs: P.kwargs, + ): + super().__init__(message, function, args, kwargs) + self._function = function + self._args = args + self._kwargs = kwargs + + async def call_without_cooldown(self) -> R: + """ + Run the command this cooldown blocked. + + Returns: + The command's return value. 
+ """ + return await self._function(*self._args, **self._kwargs) + + +@dataclass +class _CooldownItem: + non_hashable_arguments: _ArgsList + timeout_timestamp: float + + +@dataclass +class _SeparatedArguments: + """Arguments separated into their hashable and non-hashable parts.""" + + hashable: _HashableArgsTuple + non_hashable: _ArgsList + + @classmethod + def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self: + """Create a new instance from full call arguments.""" + hashable = list[Hashable]() + non_hashable = list[object]() + + for item in call_arguments: + try: + hash(item) + except TypeError: + non_hashable.append(item) + else: + hashable.append(item) + + return cls(tuple(hashable), non_hashable) + + +class _CommandCooldownManager: + """ + Manage invocation cooldowns for a command through the arguments the command is called with. + + Use `set_cooldown` to set a cooldown, + and `is_on_cooldown` to check for a cooldown for a channel with the given arguments. + A cooldown lasts for `cooldown_duration` seconds. 
+ """ + + def __init__(self, *, cooldown_duration: float): + self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]() + self._cooldown_duration = cooldown_duration + self.cleanup_task = scheduling.create_task( + self._periodical_cleanup(random.uniform(0, 10)), + name="CooldownManager cleanup", + ) + weakref.finalize(self, self.cleanup_task.cancel) + + def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None: + """Set `call_arguments` arguments on cooldown in `channel`.""" + timeout_timestamp = time.monotonic() + self._cooldown_duration + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.setdefault( + (channel, separated_arguments.hashable), + [], + ) + + for item in cooldowns_list: + if item.non_hashable_arguments == separated_arguments.non_hashable: + item.timeout_timestamp = timeout_timestamp + return + + cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp)) + + def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool: + """Check whether `call_arguments` is on cooldown in `channel`.""" + current_time = time.monotonic() + separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments) + cooldowns_list = self._cooldowns.get( + (channel, separated_arguments.hashable), + [], + ) + + for item in cooldowns_list: + if item.non_hashable_arguments == separated_arguments.non_hashable: + return item.timeout_timestamp > current_time + return False + + async def _periodical_cleanup(self, initial_delay: float) -> None: + """ + Delete stale items every hour after waiting for `initial_delay`. + + The `initial_delay` ensures cleanups are not running for every command at the same time. + A strong reference to self is only kept while cleanup is running. 
+ """ + weak_self = weakref.ref(self) + del self + + await asyncio.sleep(initial_delay) + while True: + await asyncio.sleep(60 * 60) + weak_self()._delete_stale_items() + + def _delete_stale_items(self) -> None: + """Remove expired items from internal collections.""" + current_time = time.monotonic() + + for key, cooldowns_list in self._cooldowns.copy().items(): + filtered_cooldowns = [ + cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time + ] + + if not filtered_cooldowns: + del self._cooldowns[key] + else: + self._cooldowns[key] = filtered_cooldowns + + +def _create_argument_tuple(*args: object, **kwargs: object) -> tuple[object, ...]: + return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items()) + + +def block_duplicate_invocations( + *, + cooldown_duration: float = 5, + send_notice: bool = False, + args_preprocessor: Callable[P, Iterable[object]] | None = None, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """ + Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds. + + Args: + cooldown_duration: Length of the cooldown in seconds. + send_notice: If :obj:`True`, notify the user about the cooldown with a reply. + args_preprocessor: If specified, this function is called with the args and kwargs the function is called with, + its return value is then used to check for the cooldown instead of the raw arguments. + + Returns: + A decorator that adds a wrapper which applies the cooldowns. + + Warning: + The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown. 
+ """ + + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration) + + @command_wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if args_preprocessor is not None: + all_args = args_preprocessor(*args, **kwargs) + else: + all_args = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command + ctx = typing.cast("Context[BotBase]", args[1]) + + if not isinstance(ctx.channel, discord.DMChannel): + if mgr.is_on_cooldown(ctx.channel, all_args): + if send_notice: + with suppress(discord.NotFound): + await ctx.reply("The command is on cooldown with the given arguments.") + raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs) + mgr.set_cooldown(ctx.channel, all_args) + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/pydis_core/utils/function.py b/pydis_core/utils/function.py new file mode 100644 index 00000000..d89163ec --- /dev/null +++ b/pydis_core/utils/function.py @@ -0,0 +1,111 @@ +"""Utils for manipulating functions.""" + +from __future__ import annotations + +import functools +import types +import typing +from collections.abc import Callable, Sequence, Set + +__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"] + + +if typing.TYPE_CHECKING: + _P = typing.ParamSpec("_P") + _R = typing.TypeVar("_R") + + +class GlobalNameConflictError(Exception): + """Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper.""" + + +def update_wrapper_globals( + wrapper: Callable[_P, _R], + wrapped: Callable[_P, _R], + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[_P, _R]: + r""" + Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals. 
+ + For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function + to resolve their values. This breaks for decorators that replace the function because they have + their own globals. + + .. warning:: + This function captures the state of ``wrapped``\'s module's globals when it's called; + changes won't be reflected in the new function's globals. + + Args: + wrapper: The function to wrap. + wrapped: The function to wrap with. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Raises: + :exc:`GlobalNameConflictError`: + If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints, + and is not in ``ignored_conflict_names``. + """ + wrapped = typing.cast(types.FunctionType, wrapped) + wrapper = typing.cast(types.FunctionType, wrapper) + + annotation_global_names = ( + ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) + ) + # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. + shared_globals = ( + set(wrapper.__code__.co_names) + & set(annotation_global_names) + & set(wrapped.__globals__) + & set(wrapper.__globals__) + - ignored_conflict_names + ) + if shared_globals: + raise GlobalNameConflictError( + f"wrapper and the wrapped function share the following " + f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " + f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." 
+ ) + + new_globals = wrapper.__globals__.copy() + new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) + return types.FunctionType( + code=wrapper.__code__, + globals=new_globals, + name=wrapper.__name__, + argdefs=wrapper.__defaults__, + closure=wrapper.__closure__, + ) + + +def command_wraps( + wrapped: Callable[_P, _R], + assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS, + updated: Sequence[str] = functools.WRAPPER_UPDATES, + *, + ignored_conflict_names: Set[str] = frozenset(), +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + r""" + Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation. + + See :func:`update_wrapper_globals` for more details on how the globals are updated. + + Args: + wrapped: The function to wrap with. + assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``. + updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``. + ignored_conflict_names: A set of names to ignore if a conflict between them is found. + + Returns: + A decorator that behaves like :func:`functools.wraps`, + with the wrapper replaced with the function :func:`update_wrapper_globals` returned. 
+ """ # noqa: D200 + def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]: + return functools.update_wrapper( + update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), + wrapped, + assigned, + updated, + ) + + return decorator diff --git a/pydis_core/utils/interactions.py b/pydis_core/utils/interactions.py new file mode 100644 index 00000000..3e4acffe --- /dev/null +++ b/pydis_core/utils/interactions.py @@ -0,0 +1,98 @@ +import contextlib +from typing import Optional, Sequence + +from discord import ButtonStyle, Interaction, Message, NotFound, ui + +from pydis_core.utils.logging import get_logger + +log = get_logger(__name__) + + +class ViewWithUserAndRoleCheck(ui.View): + """ + A view that allows the original invoker and moderators to interact with it. + + Args: + allowed_users: A sequence of user's ids who are allowed to interact with the view. + allowed_roles: A sequence of role ids that are allowed to interact with the view. + timeout: Timeout in seconds from last interaction with the UI before no longer accepting input. + If ``None`` then there is no timeout. + message: The message to remove the view from on timeout. This can also be set with + ``view.message = await ctx.send( ... )``` , or similar, after the view is instantiated. + """ + + def __init__( + self, + *, + allowed_users: Sequence[int], + allowed_roles: Sequence[int], + timeout: Optional[float] = 180.0, + message: Optional[Message] = None + ) -> None: + super().__init__(timeout=timeout) + self.allowed_users = allowed_users + self.allowed_roles = allowed_roles + self.message = message + + async def interaction_check(self, interaction: Interaction) -> bool: + """ + Ensure the user clicking the button is the view invoker, or a moderator. + + Args: + interaction: The interaction that occurred. 
+ """ + if interaction.user.id in self.allowed_users: + log.trace( + "Allowed interaction by %s (%d) on %d as they are an allowed user.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])): + log.trace( + "Allowed interaction by %s (%d)on %d as they have an allowed role.", + interaction.user, + interaction.user.id, + interaction.message.id, + ) + return True + + await interaction.response.send_message("This is not your button to click!", ephemeral=True) + return False + + async def on_timeout(self) -> None: + """Remove the view from ``self.message`` if set.""" + if self.message: + with contextlib.suppress(NotFound): + # Cover the case where this message has already been deleted by external means + await self.message.edit(view=None) + + +class DeleteMessageButton(ui.Button): + """ + A button that can be added to a view to delete the message containing the view on click. + + This button itself carries out no interaction checks, these should be done by the parent view. + + See :obj:`pydis_core.utils.interactions.ViewWithUserAndRoleCheck` for a view that implements basic checks. + + Args: + style (:literal-url:`ButtonStyle `): + The style of the button, set to ``ButtonStyle.secondary`` if not specified. + label: The label of the button, set to "Delete" if not specified. 
+ """ # noqa: E501 + + def __init__( + self, + *, + style: ButtonStyle = ButtonStyle.secondary, + label: str = "Delete", + **kwargs + ): + super().__init__(style=style, label=label, **kwargs) + + async def callback(self, interaction: Interaction) -> None: + """Delete the original message on button click.""" + await interaction.message.delete() diff --git a/pydis_core/utils/logging.py b/pydis_core/utils/logging.py new file mode 100644 index 00000000..7814f348 --- /dev/null +++ b/pydis_core/utils/logging.py @@ -0,0 +1,51 @@ +"""Common logging related functions.""" + +import logging +import typing + +if typing.TYPE_CHECKING: + LoggerClass = logging.Logger +else: + LoggerClass = logging.getLoggerClass() + +TRACE_LEVEL = 5 + + +class CustomLogger(LoggerClass): + """Custom implementation of the :obj:`logging.Logger` class with an added :obj:`trace` method.""" + + def trace(self, msg: str, *args, **kwargs) -> None: + """ + Log the given message with the severity ``"TRACE"``. + + To pass exception information, use the keyword argument exc_info with a true value: + + .. code-block:: py + + logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) + + Args: + msg: The message to be logged. + args, kwargs: Passed to the base log function as is. + """ + if self.isEnabledFor(TRACE_LEVEL): + self.log(TRACE_LEVEL, msg, *args, **kwargs) + + +def get_logger(name: typing.Optional[str] = None) -> CustomLogger: + """ + Utility to make mypy recognise that logger is of type :obj:`CustomLogger`. + + Args: + name: The name given to the logger. + + Returns: + An instance of the :obj:`CustomLogger` class. + """ + return typing.cast(CustomLogger, logging.getLogger(name)) + + +# Setup trace level logging so that we can use it within pydis_core. 
+logging.TRACE = TRACE_LEVEL +logging.setLoggerClass(CustomLogger) +logging.addLevelName(TRACE_LEVEL, "TRACE") diff --git a/pydis_core/utils/members.py b/pydis_core/utils/members.py new file mode 100644 index 00000000..b6eacc88 --- /dev/null +++ b/pydis_core/utils/members.py @@ -0,0 +1,57 @@ +"""Useful helper functions for interactin with :obj:`discord.Member` objects.""" +import typing +from collections import abc + +import discord + +from pydis_core.utils import logging + +log = logging.get_logger(__name__) + + +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]: + """ + Attempt to get a member from cache; on failure fetch from the API. + + Returns: + The :obj:`discord.Member` or :obj:`None` to indicate the member could not be found. + """ + if member := guild.get_member(member_id): + log.trace(f"{member} retrieved from cache.") + else: + try: + member = await guild.fetch_member(member_id) + except discord.errors.NotFound: + log.trace(f"Failed to fetch {member_id} from API.") + return None + log.trace(f"{member} fetched from API.") + return member + + +async def handle_role_change( + member: discord.Member, + coro: typing.Callable[[discord.Role], abc.Coroutine], + role: discord.Role +) -> None: + """ + Await the given ``coro`` with ``role`` as the sole argument. + + Handle errors that we expect to be raised from + :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`. + + Args: + member: The member that is being modified for logging purposes. + coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`. + role: The role to be passed to ``coro``. 
+ """ + try: + await coro(role) + except discord.NotFound: + log.error(f"Failed to change role for {member} ({member.id}): member not found") + except discord.Forbidden: + log.error( + f"Forbidden to change role for {member} ({member.id}); " + f"possibly due to role hierarchy" + ) + except discord.HTTPException as e: + log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/pydis_core/utils/regex.py b/pydis_core/utils/regex.py new file mode 100644 index 00000000..de82a1ed --- /dev/null +++ b/pydis_core/utils/regex.py @@ -0,0 +1,54 @@ +"""Common regular expressions.""" + +import re + +DISCORD_INVITE = re.compile( + r"(https?://)?(www\.)?" # Optional http(s) and www. + r"(discord([.,]|dot)gg|" # Could be discord.gg/ + r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/ + r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/ + r"discord([.,]|dot)me|" # or discord.me + r"discord([.,]|dot)li|" # or discord.li + r"discord([.,]|dot)io|" # or discord.io. + r"((?\S+)", # the invite code itself + flags=re.IGNORECASE +) +""" +Regex for Discord server invites. + +.. warning:: + This regex pattern will capture until a whitespace, if you are to use the 'invite' capture group in + any HTTP requests or similar. Please ensure you sanitise the output using something + such as :func:`urllib.parse.quote`. + +:meta hide-value: +""" + +FORMATTED_CODE_REGEX = re.compile( + r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)", # match the exact same delimiter from the start again + flags=re.DOTALL | re.IGNORECASE # "." 
also matches newlines, case insensitive +) +""" +Regex for formatted code, using Discord's code blocks. + +:meta hide-value: +""" + +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + flags=re.DOTALL # "." also matches newlines +) +""" +Regex for raw code, *not* using Discord's code blocks. + +:meta hide-value: +""" diff --git a/pydis_core/utils/scheduling.py b/pydis_core/utils/scheduling.py new file mode 100644 index 00000000..eced4a3d --- /dev/null +++ b/pydis_core/utils/scheduling.py @@ -0,0 +1,252 @@ +"""Generic python scheduler.""" + +import asyncio +import contextlib +import inspect +import typing +from collections import abc +from datetime import datetime +from functools import partial + +from pydis_core.utils import logging + + +class Scheduler: + """ + Schedule the execution of coroutines and keep track of them. + + When instantiating a :obj:`Scheduler`, a name must be provided. This name is used to distinguish the + instance's log messages from other instances. Using the name of the class or module containing + the instance is suggested. + + Coroutines can be scheduled immediately with :obj:`schedule` or in the future with :obj:`schedule_at` + or :obj:`schedule_later`. A unique ID is required to be given in order to keep track of the + resulting Tasks. Any scheduled task can be cancelled prematurely using :obj:`cancel` by providing + the same ID used to schedule it. + + The ``in`` operator is supported for checking if a task with a given ID is currently scheduled. + + Any exception raised in a scheduled task is logged when the task is done. + """ + + def __init__(self, name: str): + """ + Initialize a new :obj:`Scheduler` instance. + + Args: + name: The name of the :obj:`Scheduler`. Used in logging, and namespacing. 
+ """ + self.name = name + + self._log = logging.get_logger(f"{__name__}.{name}") + self._scheduled_tasks: dict[abc.Hashable, asyncio.Task] = {} + + def __contains__(self, task_id: abc.Hashable) -> bool: + """ + Return :obj:`True` if a task with the given ``task_id`` is currently scheduled. + + Args: + task_id: The task to look for. + + Returns: + :obj:`True` if the task was found. + """ + return task_id in self._scheduled_tasks + + def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: + """ + Schedule the execution of a ``coroutine``. + + If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + self._log.trace(f"Scheduling task #{task_id}...") + + msg = f"Cannot schedule an already started coroutine for #{task_id}" + assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg + + if task_id in self._scheduled_tasks: + self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") + coroutine.close() + return + + task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") + task.add_done_callback(partial(self._task_done_callback, task_id)) + + self._scheduled_tasks[task_id] = task + self._log.debug(f"Scheduled task #{task_id} {id(task)}.") + + def schedule_at(self, time: datetime, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None: + """ + Schedule ``coroutine`` to be executed at the given ``time``. + + If ``time`` is timezone aware, then use that timezone to calculate now() when subtracting. + If ``time`` is naïve, then use UTC. + + If ``time`` is in the past, schedule ``coroutine`` immediately. + + If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This + prevents unawaited coroutine warnings. 
Don't pass a coroutine that'll be re-used elsewhere. + + Args: + time: The time to start the task. + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() + delay = (time - now_datetime).total_seconds() + if delay > 0: + coroutine = self._await_later(delay, task_id, coroutine) + + self.schedule(task_id, coroutine) + + def schedule_later( + self, + delay: typing.Union[int, float], + task_id: abc.Hashable, + coroutine: abc.Coroutine + ) -> None: + """ + Schedule ``coroutine`` to be executed after ``delay`` seconds. + + If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + delay: How long to wait before starting the task. + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + self.schedule(task_id, self._await_later(delay, task_id, coroutine)) + + def cancel(self, task_id: abc.Hashable) -> None: + """ + Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist. + + Args: + task_id: The task's unique ID. 
+ """ + self._log.trace(f"Cancelling task #{task_id}...") + + try: + task = self._scheduled_tasks.pop(task_id) + except KeyError: + self._log.warning(f"Failed to unschedule {task_id} (no task found).") + else: + task.cancel() + + self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") + + def cancel_all(self) -> None: + """Unschedule all known tasks.""" + self._log.debug("Unscheduling all tasks") + + for task_id in self._scheduled_tasks.copy(): + self.cancel(task_id) + + async def _await_later( + self, + delay: typing.Union[int, float], + task_id: abc.Hashable, + coroutine: abc.Coroutine + ) -> None: + """Await ``coroutine`` after ``delay`` seconds.""" + try: + self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") + await asyncio.sleep(delay) + + # Use asyncio.shield to prevent the coroutine from cancelling itself. + self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") + await asyncio.shield(coroutine) + finally: + # Close it to prevent unawaited coroutine warnings, + # which would happen if the task was cancelled during the sleep. + # Only close it if it's not been awaited yet. This check is important because the + # coroutine may cancel this task, which would also trigger the finally block. + state = inspect.getcoroutinestate(coroutine) + if state == "CORO_CREATED": + self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") + coroutine.close() + else: + self._log.debug(f"Finally block reached for #{task_id}; {state=}") + + def _task_done_callback(self, task_id: abc.Hashable, done_task: asyncio.Task) -> None: + """ + Delete the task and raise its exception if one exists. + + If ``done_task`` and the task associated with ``task_id`` are different, then the latter + will not be deleted. In this case, a new task was likely rescheduled with the same ID. 
+ """ + self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") + + scheduled_task = self._scheduled_tasks.get(task_id) + + if scheduled_task and done_task is scheduled_task: + # A task for the ID exists and is the same as the done task. + # Since this is the done callback, the task is already done so no need to cancel it. + self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") + del self._scheduled_tasks[task_id] + elif scheduled_task: + # A new task was likely rescheduled with the same ID. + self._log.debug( + f"The scheduled task #{task_id} {id(scheduled_task)} " + f"and the done task {id(done_task)} differ." + ) + elif not done_task.cancelled(): + self._log.warning( + f"Task #{task_id} not found while handling task {id(done_task)}! " + f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." + ) + + with contextlib.suppress(asyncio.CancelledError): + exception = done_task.exception() + # Log the exception if one exists. + if exception: + self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) + + +TASK_RETURN = typing.TypeVar("TASK_RETURN") + + +def create_task( + coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN], + *, + suppressed_exceptions: tuple[type[Exception], ...] = (), + event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, + **kwargs, +) -> asyncio.Task[TASK_RETURN]: + """ + Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task. + + If the ``event_loop`` kwarg is provided, the task is created from that event loop, + otherwise the running loop is used. + + Args: + coro: The function to call. + suppressed_exceptions: Exceptions to be handled by the task. + event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from. + kwargs: Passed to :py:func:`asyncio.create_task`. + + Returns: + asyncio.Task: The wrapped task. 
+ """ + if event_loop is not None: + task = event_loop.create_task(coro, **kwargs) + else: + task = asyncio.create_task(coro, **kwargs) + task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) + return task + + +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None: + """Retrieve and log the exception raised in ``task`` if one exists.""" + with contextlib.suppress(asyncio.CancelledError): + exception = task.exception() + # Log the exception if one exists. + if exception and not isinstance(exception, suppressed_exceptions): + log = logging.get_logger(__name__) + log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/pyproject.toml b/pyproject.toml index d7a4bfb5..b742991d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] -name = "bot-core" -version = "8.2.1" -description = "PyDis bot core provides core functionality and utility to the bots of the Python Discord community." +name = "pydis_core" +version = "9.0.0" +description = "PyDis core provides core functionality and utility to the bots of the Python Discord community." 
authors = ["Python Discord "] license = "MIT" classifiers=[ @@ -12,11 +12,12 @@ classifiers=[ "Programming Language :: Python :: 3", ] packages = [ - { include = "botcore" }, + { include = "pydis_core" }, ] include = ["LICENSE"] exclude = ["tests", "tests.*"] readme = "README.md" +homepage = "https://pythondiscord.com/" documentation = "https://bot-core.pythondiscord.com/" repository = "https://github.com/python-discord/bot-core" keywords = ["bot", "discord", "discord.py"] @@ -68,5 +69,5 @@ build-backend = "poetry.core.masonry.api" [tool.coverage.run] branch = true -source_pkgs = ["botcore"] +source_pkgs = ["pydis_core"] source = ["tests"] diff --git a/tests/botcore/test_api.py b/tests/botcore/test_api.py deleted file mode 100644 index 86c9e5f3..00000000 --- a/tests/botcore/test_api.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest -from unittest.mock import MagicMock - -from botcore import site_api - - -class APIClientTests(unittest.IsolatedAsyncioTestCase): - """Tests for botcore's site API client.""" - - @classmethod - def setUpClass(cls): - """Sets up the shared fixtures for the tests.""" - cls.error_api_response = MagicMock() - cls.error_api_response.status = 999 - - def test_response_code_error_default_initialization(self): - """Test the default initialization of `ResponseCodeError` without `text` or `json`""" - error = site_api.ResponseCodeError(response=self.error_api_response) - - self.assertIs(error.status, self.error_api_response.status) - self.assertEqual(error.response_json, {}) - self.assertEqual(error.response_text, None) - self.assertIs(error.response, self.error_api_response) - - def test_response_code_error_string_representation_default_initialization(self): - """Test the string representation of `ResponseCodeError` initialized without text or json.""" - error = site_api.ResponseCodeError(response=self.error_api_response) - self.assertEqual( - str(error), - f"Status: {self.error_api_response.status} Response: {None}" - ) - - def 
test_response_code_error_initialization_with_json(self): - """Test the initialization of `ResponseCodeError` with json.""" - json_data = {'hello': 'world'} - error = site_api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data, - ) - self.assertEqual(error.response_json, json_data) - self.assertEqual(error.response_text, None) - - def test_response_code_error_string_representation_with_nonempty_response_json(self): - """Test the string representation of `ResponseCodeError` initialized with json.""" - json_data = {'hello': 'world'} - error = site_api.ResponseCodeError( - response=self.error_api_response, - response_json=json_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") - - def test_response_code_error_initialization_with_text(self): - """Test the initialization of `ResponseCodeError` with text.""" - text_data = 'Lemon will eat your soul' - error = site_api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data, - ) - self.assertEqual(error.response_text, text_data) - self.assertEqual(error.response_json, {}) - - def test_response_code_error_string_representation_with_nonempty_response_text(self): - """Test the string representation of `ResponseCodeError` initialized with text.""" - text_data = 'Lemon will eat your soul' - error = site_api.ResponseCodeError( - response=self.error_api_response, - response_text=text_data - ) - self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") diff --git a/tests/botcore/utils/test_cooldown.py b/tests/botcore/utils/test_cooldown.py deleted file mode 100644 index 00e5a052..00000000 --- a/tests/botcore/utils/test_cooldown.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest -from unittest.mock import patch - -from botcore.utils.cooldown import _CommandCooldownManager, _create_argument_tuple - - -class CommandCooldownManagerTests(unittest.IsolatedAsyncioTestCase): - test_call_args = 
( - _create_argument_tuple(0), - _create_argument_tuple(a=0), - _create_argument_tuple([]), - _create_argument_tuple(a=[]), - _create_argument_tuple(1, 2, 3, a=4, b=5, c=6), - _create_argument_tuple([1], [2], [3], a=[4], b=[5], c=[6]), - _create_argument_tuple([1], 2, [3], a=4, b=[5], c=6), - ) - - async def asyncSetUp(self): - self.cooldown_manager = _CommandCooldownManager(cooldown_duration=5) - - def test_no_cooldown_on_unset(self): - for call_args in self.test_call_args: - with self.subTest(arguments_tuple=call_args, channel=0): - self.assertFalse(self.cooldown_manager.is_on_cooldown(0, call_args)) - - for call_args in self.test_call_args: - with self.subTest(arguments_tuple=call_args, channel=1): - self.assertFalse(self.cooldown_manager.is_on_cooldown(1, call_args)) - - @patch("time.monotonic") - def test_cooldown_is_set(self, monotonic): - monotonic.side_effect = lambda: 0 - for call_args in self.test_call_args: - with self.subTest(arguments_tuple=call_args): - self.cooldown_manager.set_cooldown(0, call_args) - self.assertTrue(self.cooldown_manager.is_on_cooldown(0, call_args)) - - @patch("time.monotonic") - def test_cooldown_expires(self, monotonic): - for call_args in self.test_call_args: - monotonic.side_effect = (0, 1000) - with self.subTest(arguments_tuple=call_args): - self.cooldown_manager.set_cooldown(0, call_args) - self.assertFalse(self.cooldown_manager.is_on_cooldown(0, call_args)) - - def test_keywords_and_tuples_differentiated(self): - self.cooldown_manager.set_cooldown(0, _create_argument_tuple(("a", 0))) - self.assertFalse(self.cooldown_manager.is_on_cooldown(0, _create_argument_tuple(a=0))) - self.assertTrue(self.cooldown_manager.is_on_cooldown(0, _create_argument_tuple(("a", 0)))) diff --git a/tests/botcore/utils/test_regex.py b/tests/botcore/utils/test_regex.py deleted file mode 100644 index 491e22bd..00000000 --- a/tests/botcore/utils/test_regex.py +++ /dev/null @@ -1,65 +0,0 @@ -import unittest -from typing import Optional - -from 
botcore.utils.regex import DISCORD_INVITE - - -def match_regex(s: str) -> Optional[str]: - """Helper function to run re.match on a string. - - Return the invite capture group, if the string matches the pattern - else return None - """ - result = DISCORD_INVITE.match(s) - return result if result is None else result.group("invite") - - -def search_regex(s: str) -> Optional[str]: - """Helper function to run re.search on a string. - - Return the invite capture group, if the string matches the pattern - else return None - """ - result = DISCORD_INVITE.search(s) - return result if result is None else result.group("invite") - - -class UtilsRegexTests(unittest.TestCase): - - def test_discord_invite_positives(self): - """Test the DISCORD_INVITE regex on a set of strings we would expect to capture.""" - - self.assertEqual(match_regex("discord.gg/python"), "python") - self.assertEqual(match_regex("https://discord.gg/python"), "python") - self.assertEqual(match_regex("https://www.discord.gg/python"), "python") - self.assertEqual(match_regex("discord.com/invite/python"), "python") - self.assertEqual(match_regex("www.discord.com/invite/python"), "python") - self.assertEqual(match_regex("discordapp.com/invite/python"), "python") - self.assertEqual(match_regex("discord.me/python"), "python") - self.assertEqual(match_regex("discord.li/python"), "python") - self.assertEqual(match_regex("discord.io/python"), "python") - self.assertEqual(match_regex(".gg/python"), "python") - - self.assertEqual(match_regex("discord.gg/python/but/extra"), "python/but/extra") - self.assertEqual(match_regex("discord.me/this/isnt/python"), "this/isnt/python") - self.assertEqual(match_regex(".gg/a/a/a/a/a/a/a/a/a/a/a"), "a/a/a/a/a/a/a/a/a/a/a") - self.assertEqual(match_regex("discordapp.com/invite/python/snakescord"), "python/snakescord") - self.assertEqual(match_regex("http://discord.gg/python/%20/notpython"), "python/%20/notpython") - self.assertEqual(match_regex("discord.gg/python?=ts/notpython"), 
"python?=ts/notpython") - self.assertEqual(match_regex("https://discord.gg/python#fragment/notpython"), "python#fragment/notpython") - self.assertEqual(match_regex("https://discord.gg/python/~/notpython"), "python/~/notpython") - - self.assertEqual(search_regex("https://discord.gg/python with whitespace"), "python") - self.assertEqual(search_regex(" https://discord.gg/python "), "python") - - def test_discord_invite_negatives(self): - """Test the DISCORD_INVITE regex on a set of strings we would expect to not capture.""" - - self.assertEqual(match_regex("another string"), None) - self.assertEqual(match_regex("https://pythondiscord.com"), None) - self.assertEqual(match_regex("https://discord.com"), None) - self.assertEqual(match_regex("https://discord.gg"), None) - self.assertEqual(match_regex("https://discord.gg/ python"), None) - - self.assertEqual(search_regex("https://discord.com with whitespace"), None) - self.assertEqual(search_regex(" https://discord.com "), None) diff --git a/tests/pydis_core/test_api.py b/tests/pydis_core/test_api.py new file mode 100644 index 00000000..92444e19 --- /dev/null +++ b/tests/pydis_core/test_api.py @@ -0,0 +1,69 @@ +import unittest +from unittest.mock import MagicMock + +from pydis_core import site_api + + +class APIClientTests(unittest.IsolatedAsyncioTestCase): + """Tests for pydis_core's site API client.""" + + @classmethod + def setUpClass(cls): + """Sets up the shared fixtures for the tests.""" + cls.error_api_response = MagicMock() + cls.error_api_response.status = 999 + + def test_response_code_error_default_initialization(self): + """Test the default initialization of `ResponseCodeError` without `text` or `json`""" + error = site_api.ResponseCodeError(response=self.error_api_response) + + self.assertIs(error.status, self.error_api_response.status) + self.assertEqual(error.response_json, {}) + self.assertEqual(error.response_text, None) + self.assertIs(error.response, self.error_api_response) + + def 
test_response_code_error_string_representation_default_initialization(self): + """Test the string representation of `ResponseCodeError` initialized without text or json.""" + error = site_api.ResponseCodeError(response=self.error_api_response) + self.assertEqual( + str(error), + f"Status: {self.error_api_response.status} Response: {None}" + ) + + def test_response_code_error_initialization_with_json(self): + """Test the initialization of `ResponseCodeError` with json.""" + json_data = {'hello': 'world'} + error = site_api.ResponseCodeError( + response=self.error_api_response, + response_json=json_data, + ) + self.assertEqual(error.response_json, json_data) + self.assertEqual(error.response_text, None) + + def test_response_code_error_string_representation_with_nonempty_response_json(self): + """Test the string representation of `ResponseCodeError` initialized with json.""" + json_data = {'hello': 'world'} + error = site_api.ResponseCodeError( + response=self.error_api_response, + response_json=json_data + ) + self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") + + def test_response_code_error_initialization_with_text(self): + """Test the initialization of `ResponseCodeError` with text.""" + text_data = 'Lemon will eat your soul' + error = site_api.ResponseCodeError( + response=self.error_api_response, + response_text=text_data, + ) + self.assertEqual(error.response_text, text_data) + self.assertEqual(error.response_json, {}) + + def test_response_code_error_string_representation_with_nonempty_response_text(self): + """Test the string representation of `ResponseCodeError` initialized with text.""" + text_data = 'Lemon will eat your soul' + error = site_api.ResponseCodeError( + response=self.error_api_response, + response_text=text_data + ) + self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") diff --git a/tests/pydis_core/utils/test_cooldown.py 
b/tests/pydis_core/utils/test_cooldown.py new file mode 100644 index 00000000..eed16da3 --- /dev/null +++ b/tests/pydis_core/utils/test_cooldown.py @@ -0,0 +1,49 @@ +import unittest +from unittest.mock import patch + +from pydis_core.utils.cooldown import _CommandCooldownManager, _create_argument_tuple + + +class CommandCooldownManagerTests(unittest.IsolatedAsyncioTestCase): + test_call_args = ( + _create_argument_tuple(0), + _create_argument_tuple(a=0), + _create_argument_tuple([]), + _create_argument_tuple(a=[]), + _create_argument_tuple(1, 2, 3, a=4, b=5, c=6), + _create_argument_tuple([1], [2], [3], a=[4], b=[5], c=[6]), + _create_argument_tuple([1], 2, [3], a=4, b=[5], c=6), + ) + + async def asyncSetUp(self): + self.cooldown_manager = _CommandCooldownManager(cooldown_duration=5) + + def test_no_cooldown_on_unset(self): + for call_args in self.test_call_args: + with self.subTest(arguments_tuple=call_args, channel=0): + self.assertFalse(self.cooldown_manager.is_on_cooldown(0, call_args)) + + for call_args in self.test_call_args: + with self.subTest(arguments_tuple=call_args, channel=1): + self.assertFalse(self.cooldown_manager.is_on_cooldown(1, call_args)) + + @patch("time.monotonic") + def test_cooldown_is_set(self, monotonic): + monotonic.side_effect = lambda: 0 + for call_args in self.test_call_args: + with self.subTest(arguments_tuple=call_args): + self.cooldown_manager.set_cooldown(0, call_args) + self.assertTrue(self.cooldown_manager.is_on_cooldown(0, call_args)) + + @patch("time.monotonic") + def test_cooldown_expires(self, monotonic): + for call_args in self.test_call_args: + monotonic.side_effect = (0, 1000) + with self.subTest(arguments_tuple=call_args): + self.cooldown_manager.set_cooldown(0, call_args) + self.assertFalse(self.cooldown_manager.is_on_cooldown(0, call_args)) + + def test_keywords_and_tuples_differentiated(self): + self.cooldown_manager.set_cooldown(0, _create_argument_tuple(("a", 0))) + 
self.assertFalse(self.cooldown_manager.is_on_cooldown(0, _create_argument_tuple(a=0))) + self.assertTrue(self.cooldown_manager.is_on_cooldown(0, _create_argument_tuple(("a", 0)))) diff --git a/tests/pydis_core/utils/test_regex.py b/tests/pydis_core/utils/test_regex.py new file mode 100644 index 00000000..01a2412b --- /dev/null +++ b/tests/pydis_core/utils/test_regex.py @@ -0,0 +1,65 @@ +import unittest +from typing import Optional + +from pydis_core.utils.regex import DISCORD_INVITE + + +def match_regex(s: str) -> Optional[str]: + """Helper function to run re.match on a string. + + Return the invite capture group, if the string matches the pattern + else return None + """ + result = DISCORD_INVITE.match(s) + return result if result is None else result.group("invite") + + +def search_regex(s: str) -> Optional[str]: + """Helper function to run re.search on a string. + + Return the invite capture group, if the string matches the pattern + else return None + """ + result = DISCORD_INVITE.search(s) + return result if result is None else result.group("invite") + + +class UtilsRegexTests(unittest.TestCase): + + def test_discord_invite_positives(self): + """Test the DISCORD_INVITE regex on a set of strings we would expect to capture.""" + + self.assertEqual(match_regex("discord.gg/python"), "python") + self.assertEqual(match_regex("https://discord.gg/python"), "python") + self.assertEqual(match_regex("https://www.discord.gg/python"), "python") + self.assertEqual(match_regex("discord.com/invite/python"), "python") + self.assertEqual(match_regex("www.discord.com/invite/python"), "python") + self.assertEqual(match_regex("discordapp.com/invite/python"), "python") + self.assertEqual(match_regex("discord.me/python"), "python") + self.assertEqual(match_regex("discord.li/python"), "python") + self.assertEqual(match_regex("discord.io/python"), "python") + self.assertEqual(match_regex(".gg/python"), "python") + + self.assertEqual(match_regex("discord.gg/python/but/extra"), 
"python/but/extra") + self.assertEqual(match_regex("discord.me/this/isnt/python"), "this/isnt/python") + self.assertEqual(match_regex(".gg/a/a/a/a/a/a/a/a/a/a/a"), "a/a/a/a/a/a/a/a/a/a/a") + self.assertEqual(match_regex("discordapp.com/invite/python/snakescord"), "python/snakescord") + self.assertEqual(match_regex("http://discord.gg/python/%20/notpython"), "python/%20/notpython") + self.assertEqual(match_regex("discord.gg/python?=ts/notpython"), "python?=ts/notpython") + self.assertEqual(match_regex("https://discord.gg/python#fragment/notpython"), "python#fragment/notpython") + self.assertEqual(match_regex("https://discord.gg/python/~/notpython"), "python/~/notpython") + + self.assertEqual(search_regex("https://discord.gg/python with whitespace"), "python") + self.assertEqual(search_regex(" https://discord.gg/python "), "python") + + def test_discord_invite_negatives(self): + """Test the DISCORD_INVITE regex on a set of strings we would expect to not capture.""" + + self.assertEqual(match_regex("another string"), None) + self.assertEqual(match_regex("https://pythondiscord.com"), None) + self.assertEqual(match_regex("https://discord.com"), None) + self.assertEqual(match_regex("https://discord.gg"), None) + self.assertEqual(match_regex("https://discord.gg/ python"), None) + + self.assertEqual(search_regex("https://discord.com with whitespace"), None) + self.assertEqual(search_regex(" https://discord.com "), None) diff --git a/tox.ini b/tox.ini index 717e412d..1450196c 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ max-line-length=120 docstring-convention=all import-order-style=pycharm -application_import_names=botcore,docs,tests +application_import_names=pydis_core,docs,tests exclude=.cache,.venv,.git,constants.py,bot/ ignore= B311,W503,E226,S311,T000,E731 -- cgit v1.2.3