From bcdb3a77690e1e224225627f085d86689353e1cb Mon Sep 17 00:00:00 2001 From: Hassan Abouelela Date: Mon, 21 Feb 2022 12:58:10 +0000 Subject: Port many utilities from bot --- botcore/__init__.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'botcore/__init__.py') diff --git a/botcore/__init__.py b/botcore/__init__.py index c582d0df..f32e6bf2 100644 --- a/botcore/__init__.py +++ b/botcore/__init__.py @@ -1,9 +1,16 @@ -from botcore import ( - regex, -) +"""Useful utilities and tools for discord bot development.""" + +from botcore import (caching, channel, extensions, exts, loggers, members, regex, scheduling) __all__ = [ + caching, + channel, + extensions, + exts, + loggers, + members, regex, + scheduling, ] __all__ = list(map(lambda module: module.__name__, __all__)) -- cgit v1.2.3 From 060bad105dc2569fc485adb03b985aa2ab5d367e Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 22 Feb 2022 23:47:09 +0000 Subject: Move new utilities to the util namespace --- botcore/__init__.py | 10 +- botcore/caching.py | 65 ------------ botcore/channel.py | 26 ----- botcore/extensions.py | 52 ---------- botcore/loggers.py | 45 -------- botcore/members.py | 48 --------- botcore/regex.py | 48 --------- botcore/scheduling.py | 246 -------------------------------------------- botcore/utils/__init__.py | 15 +++ botcore/utils/caching.py | 65 ++++++++++++ botcore/utils/channel.py | 26 +++++ botcore/utils/extensions.py | 52 ++++++++++ botcore/utils/loggers.py | 45 ++++++++ botcore/utils/members.py | 48 +++++++++ botcore/utils/regex.py | 48 +++++++++ botcore/utils/scheduling.py | 246 ++++++++++++++++++++++++++++++++++++++++++++ 16 files changed, 547 insertions(+), 538 deletions(-) delete mode 100644 botcore/caching.py delete mode 100644 botcore/channel.py delete mode 100644 botcore/extensions.py delete mode 100644 botcore/loggers.py delete mode 100644 botcore/members.py delete mode 100644 botcore/regex.py delete mode 100644 botcore/scheduling.py create mode 100644 botcore/utils/__init__.py create mode 100644 botcore/utils/caching.py create mode 100644 botcore/utils/channel.py create mode 100644 botcore/utils/extensions.py create mode 100644 botcore/utils/loggers.py create mode 100644 botcore/utils/members.py create mode 100644 botcore/utils/regex.py create mode 100644 botcore/utils/scheduling.py (limited to 'botcore/__init__.py') diff --git a/botcore/__init__.py b/botcore/__init__.py index f32e6bf2..d910f393 100644 --- a/botcore/__init__.py +++ b/botcore/__init__.py @@ -1,16 +1,10 @@ """Useful utilities and tools for discord bot development.""" -from botcore import (caching, channel, extensions, exts, loggers, members, regex, scheduling) +from botcore import exts, utils __all__ = [ - caching, - channel, - extensions, exts, - loggers, - members, - regex, - scheduling, + utils, ] __all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/caching.py b/botcore/caching.py deleted file mode 100644 index ea71ed1d..00000000 --- a/botcore/caching.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Utilities related to custom caches.""" - -import functools -import typing -from collections import OrderedDict - - -class AsyncCache: - """ - LRU cache implementation for coroutines. - - Once the cache exceeds the maximum size, keys are deleted in FIFO order. - - An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. - """ - - def __init__(self, max_size: int = 128): - """ - Initialise a new AsyncCache instance. 
- - Args: - max_size: How many items to store in the cache. - """ - self._cache = OrderedDict() - self._max_size = max_size - - def __call__(self, arg_offset: int = 0) -> typing.Callable: - """ - Decorator for async cache. - - Args: - arg_offset: The offset for the position of the key argument. - - Returns: - A decorator to wrap the target function. - """ - - def decorator(function: typing.Callable) -> typing.Callable: - """ - Define the async cache decorator. - - Args: - function: The function to wrap. - - Returns: - The wrapped function. - """ - - @functools.wraps(function) - async def wrapper(*args) -> typing.Any: - """Decorator wrapper for the caching logic.""" - key = args[arg_offset:] - - if key not in self._cache: - if len(self._cache) > self._max_size: - self._cache.popitem(last=False) - - self._cache[key] = await function(*args) - return self._cache[key] - return wrapper - return decorator - - def clear(self) -> None: - """Clear cache instance.""" - self._cache.clear() diff --git a/botcore/channel.py b/botcore/channel.py deleted file mode 100644 index b19b4f08..00000000 --- a/botcore/channel.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Utilities for interacting with discord channels.""" - -import discord -from discord.ext.commands import Bot - -from botcore import loggers - -log = loggers.get_logger(__name__) - - -def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: - """Return True if `channel` is within a category with `category_id`.""" - return getattr(channel, "category_id", None) == category_id - - -async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel: - """Attempt to get or fetch a channel and return it.""" - log.trace(f"Getting the channel {channel_id}.") - - channel = bot.get_channel(channel_id) - if not channel: - log.debug(f"Channel {channel_id} is not in cache; fetching from API.") - channel = await bot.fetch_channel(channel_id) - - log.trace(f"Channel #{channel} ({channel_id}) retrieved.") - return channel diff --git a/botcore/extensions.py b/botcore/extensions.py deleted file mode 100644 index c8f200ad..00000000 --- a/botcore/extensions.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Utilities for loading discord extensions.""" - -import importlib -import inspect -import pkgutil -import types -from typing import NoReturn - - -def unqualify(name: str) -> str: - """ - Return an unqualified name given a qualified module/package `name`. - - Args: - name: The module name to unqualify. - - Returns: - The unqualified module name. - """ - return name.rsplit(".", maxsplit=1)[-1] - - -def walk_extensions(module: types.ModuleType) -> frozenset[str]: - """ - Yield extension names from the bot.exts subpackage. - - Args: - module (types.ModuleType): The module to look for extensions in. - - Returns: - A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. - """ - - def on_error(name: str) -> NoReturn: - raise ImportError(name=name) # pragma: no cover - - modules = set() - - for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): - if unqualify(module_info.name).startswith("_"): - # Ignore module/package names starting with an underscore. - continue - - if module_info.ispkg: - imported = importlib.import_module(module_info.name) - if not inspect.isfunction(getattr(imported, "setup", None)): - # If it lacks a setup function, it's not an extension. 
- continue - - modules.add(module_info.name) - - return frozenset(modules) diff --git a/botcore/loggers.py b/botcore/loggers.py deleted file mode 100644 index ac1db920..00000000 --- a/botcore/loggers.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Custom logging class.""" - -import logging -import typing - -if typing.TYPE_CHECKING: - LoggerClass = logging.Logger -else: - LoggerClass = logging.getLoggerClass() - -TRACE_LEVEL = 5 - - -class CustomLogger(LoggerClass): - """Custom implementation of the `Logger` class with an added `trace` method.""" - - def trace(self, msg: str, *args, **kwargs) -> None: - """ - Log 'msg % args' with severity 'TRACE'. - - To pass exception information, use the keyword argument exc_info with a true value: - - .. code-block:: py - - logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) - - Args: - msg: The message to be logged. - args, kwargs: Passed to the base log function as is. - """ - if self.isEnabledFor(TRACE_LEVEL): - self.log(TRACE_LEVEL, msg, *args, **kwargs) - - -def get_logger(name: typing.Optional[str] = None) -> CustomLogger: - """ - Utility to make mypy recognise that logger is of type `CustomLogger`. - - Args: - name: The name given to the logger. - - Returns: - An instance of the `CustomLogger` class. - """ - return typing.cast(CustomLogger, logging.getLogger(name)) diff --git a/botcore/members.py b/botcore/members.py deleted file mode 100644 index 07b16ea3..00000000 --- a/botcore/members.py +++ /dev/null @@ -1,48 +0,0 @@ -import typing - -import discord - -from botcore.loggers import get_logger - -log = get_logger(__name__) - - -async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]: - """ - Attempt to get a member from cache; on failure fetch from the API. - - Return `None` to indicate the member could not be found. - """ - if member := guild.get_member(member_id): - log.trace(f"{member} retrieved from cache.") - else: - try: - member = await guild.fetch_member(member_id) - except discord.errors.NotFound: - log.trace(f"Failed to fetch {member_id} from API.") - return None - log.trace(f"{member} fetched from API.") - return member - - -async def handle_role_change( - member: discord.Member, - coro: typing.Callable[..., typing.Coroutine], - role: discord.Role -) -> None: - """ - Change `member`'s cooldown role via awaiting `coro` and handle errors. - - `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. - """ - try: - await coro(role) - except discord.NotFound: - log.error(f"Failed to change role for {member} ({member.id}): member not found") - except discord.Forbidden: - log.error( - f"Forbidden to change role for {member} ({member.id}); " - f"possibly due to role hierarchy" - ) - except discord.HTTPException as e: - log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/botcore/regex.py b/botcore/regex.py deleted file mode 100644 index 036a5113..00000000 --- a/botcore/regex.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Common regular expressions.""" - -import re - -DISCORD_INVITE = re.compile( - r"(discord([.,]|dot)gg|" # Could be discord.gg/ - r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/ - r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/ - r"discord([.,]|dot)me|" # or discord.me - r"discord([.,]|dot)li|" # or discord.li - r"discord([.,]|dot)io|" # or discord.io. - r"((?[a-zA-Z0-9\-]+)", # the invite code itself - flags=re.IGNORECASE -) -""" -Regex for discord server invites. 
- -:meta hide-value: -""" - -FORMATTED_CODE_REGEX = re.compile( - r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block - r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) - r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all code inside the markup - r"\s*" # any more whitespace before the end of the code markup - r"(?P=delim)", # match the exact same delimiter from the start again - re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive -) -""" -Regex for formatted code, using Discord's code blocks. - -:meta hide-value: -""" - -RAW_CODE_REGEX = re.compile( - r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code - r"(?P.*?)" # extract all the rest as code - r"\s*$", # any trailing whitespace until the end of the string - re.DOTALL # "." also matches newlines -) -""" -Regex for raw code, *not* using Discord's code blocks. - -:meta hide-value: -""" diff --git a/botcore/scheduling.py b/botcore/scheduling.py deleted file mode 100644 index 206e5e79..00000000 --- a/botcore/scheduling.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Generic python scheduler.""" - -import asyncio -import contextlib -import inspect -import typing -from datetime import datetime -from functools import partial - -from botcore.loggers import get_logger - - -class Scheduler: - """ - Schedule the execution of coroutines and keep track of them. - - When instantiating a Scheduler, a name must be provided. This name is used to distinguish the - instance's log messages from other instances. Using the name of the class or module containing - the instance is suggested. - - Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` - or `schedule_later`. A unique ID is required to be given in order to keep track of the - resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing - the same ID used to schedule it. The `in` operator is supported for checking if a task with a - given ID is currently scheduled. - - Any exception raised in a scheduled task is logged when the task is done. - """ - - def __init__(self, name: str): - """ - Initialize a new Scheduler instance. - - Args: - name: The name of the scheduler. Used in logging, and namespacing. - """ - self.name = name - - self._log = get_logger(f"{__name__}.{name}") - self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {} - - def __contains__(self, task_id: typing.Hashable) -> bool: - """ - Return True if a task with the given `task_id` is currently scheduled. - - Args: - task_id: The task to look for. - - Returns: - True if the task was found. - """ - return task_id in self._scheduled_tasks - - def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: - """ - Schedule the execution of a `coroutine`. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - - Args: - task_id: A unique ID to create the task with. - coroutine: The function to be called. 
- """ - self._log.trace(f"Scheduling task #{task_id}...") - - msg = f"Cannot schedule an already started coroutine for #{task_id}" - assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg - - if task_id in self._scheduled_tasks: - self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") - coroutine.close() - return - - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") - task.add_done_callback(partial(self._task_done_callback, task_id)) - - self._scheduled_tasks[task_id] = task - self._log.debug(f"Scheduled task #{task_id} {id(task)}.") - - def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: - """ - Schedule `coroutine` to be executed at the given `time`. - - If `time` is timezone aware, then use that timezone to calculate now() when subtracting. - If `time` is naïve, then use UTC. - - If `time` is in the past, schedule `coroutine` immediately. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - - Args: - time: The time to start the task. - task_id: A unique ID to create the task with. - coroutine: The function to be called. - """ - now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() - delay = (time - now_datetime).total_seconds() - if delay > 0: - coroutine = self._await_later(delay, task_id, coroutine) - - self.schedule(task_id, coroutine) - - def schedule_later( - self, - delay: typing.Union[int, float], - task_id: typing.Hashable, - coroutine: typing.Coroutine - ) -> None: - """ - Schedule `coroutine` to be executed after `delay` seconds. - - If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This - prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. - - Args: - delay: How long to wait before starting the task. - task_id: A unique ID to create the task with. - coroutine: The function to be called. - """ - self.schedule(task_id, self._await_later(delay, task_id, coroutine)) - - def cancel(self, task_id: typing.Hashable) -> None: - """ - Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist. - - Args: - task_id: The task's unique ID. - """ - self._log.trace(f"Cancelling task #{task_id}...") - - try: - task = self._scheduled_tasks.pop(task_id) - except KeyError: - self._log.warning(f"Failed to unschedule {task_id} (no task found).") - else: - task.cancel() - - self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") - - def cancel_all(self) -> None: - """Unschedule all known tasks.""" - self._log.debug("Unscheduling all tasks") - - for task_id in self._scheduled_tasks.copy(): - self.cancel(task_id) - - async def _await_later( - self, - delay: typing.Union[int, float], - task_id: typing.Hashable, - coroutine: typing.Coroutine - ) -> None: - """Await `coroutine` after `delay` seconds.""" - try: - self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") - await asyncio.sleep(delay) - - # Use asyncio.shield to prevent the coroutine from cancelling itself. - self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") - await asyncio.shield(coroutine) - finally: - # Close it to prevent unawaited coroutine warnings, - # which would happen if the task was cancelled during the sleep. - # Only close it if it's not been awaited yet. 
This check is important because the - # coroutine may cancel this task, which would also trigger the finally block. - state = inspect.getcoroutinestate(coroutine) - if state == "CORO_CREATED": - self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") - coroutine.close() - else: - self._log.debug(f"Finally block reached for #{task_id}; {state=}") - - def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None: - """ - Delete the task and raise its exception if one exists. - - If `done_task` and the task associated with `task_id` are different, then the latter - will not be deleted. In this case, a new task was likely rescheduled with the same ID. - """ - self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") - - scheduled_task = self._scheduled_tasks.get(task_id) - - if scheduled_task and done_task is scheduled_task: - # A task for the ID exists and is the same as the done task. - # Since this is the done callback, the task is already done so no need to cancel it. - self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") - del self._scheduled_tasks[task_id] - elif scheduled_task: - # A new task was likely rescheduled with the same ID. - self._log.debug( - f"The scheduled task #{task_id} {id(scheduled_task)} " - f"and the done task {id(done_task)} differ." - ) - elif not done_task.cancelled(): - self._log.warning( - f"Task #{task_id} not found while handling task {id(done_task)}! " - f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." - ) - - with contextlib.suppress(asyncio.CancelledError): - exception = done_task.exception() - # Log the exception if one exists. - if exception: - self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) - - -def create_task( - coro: typing.Awaitable, - *, - suppressed_exceptions: tuple[typing.Type[Exception]] = (), - event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, - **kwargs, -) -> asyncio.Task: - """ - Wrapper for creating asyncio `Tasks` which logs exceptions raised in the task. - - If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used. - - Args: - coro: The function to call. - suppressed_exceptions: Exceptions to be handled by the task. - event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from. - kwargs: Passed to :py:func:`asyncio.create_task`. - - Returns: - asyncio.Task: The wrapped task. - """ - if event_loop is not None: - task = event_loop.create_task(coro, **kwargs) - else: - task = asyncio.create_task(coro, **kwargs) - task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) - return task - - -def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None: - """Retrieve and log the exception raised in `task` if one exists.""" - with contextlib.suppress(asyncio.CancelledError): - exception = task.exception() - # Log the exception if one exists. 
- if exception and not isinstance(exception, suppressed_exceptions): - log = get_logger(__name__) - log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py new file mode 100644 index 00000000..554e8ad1 --- /dev/null +++ b/botcore/utils/__init__.py @@ -0,0 +1,15 @@ +"""Useful utilities and tools for discord bot development.""" + +from botcore.utils import (caching, channel, extensions, loggers, members, regex, scheduling) + +__all__ = [ + caching, + channel, + extensions, + loggers, + members, + regex, + scheduling, +] + +__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/utils/caching.py b/botcore/utils/caching.py new file mode 100644 index 00000000..ea71ed1d --- /dev/null +++ b/botcore/utils/caching.py @@ -0,0 +1,65 @@ +"""Utilities related to custom caches.""" + +import functools +import typing +from collections import OrderedDict + + +class AsyncCache: + """ + LRU cache implementation for coroutines. + + Once the cache exceeds the maximum size, keys are deleted in FIFO order. + + An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. + """ + + def __init__(self, max_size: int = 128): + """ + Initialise a new AsyncCache instance. + + Args: + max_size: How many items to store in the cache. + """ + self._cache = OrderedDict() + self._max_size = max_size + + def __call__(self, arg_offset: int = 0) -> typing.Callable: + """ + Decorator for async cache. + + Args: + arg_offset: The offset for the position of the key argument. + + Returns: + A decorator to wrap the target function. + """ + + def decorator(function: typing.Callable) -> typing.Callable: + """ + Define the async cache decorator. + + Args: + function: The function to wrap. + + Returns: + The wrapped function. 
+ """ + + @functools.wraps(function) + async def wrapper(*args) -> typing.Any: + """Decorator wrapper for the caching logic.""" + key = args[arg_offset:] + + if key not in self._cache: + if len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + self._cache[key] = await function(*args) + return self._cache[key] + return wrapper + return decorator + + def clear(self) -> None: + """Clear cache instance.""" + self._cache.clear() diff --git a/botcore/utils/channel.py b/botcore/utils/channel.py new file mode 100644 index 00000000..7e0fc387 --- /dev/null +++ b/botcore/utils/channel.py @@ -0,0 +1,26 @@ +"""Utilities for interacting with discord channels.""" + +import discord +from discord.ext.commands import Bot + +from botcore.utils import loggers + +log = loggers.get_logger(__name__) + + +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: + """Return True if `channel` is within a category with `category_id`.""" + return getattr(channel, "category_id", None) == category_id + + +async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel: + """Attempt to get or fetch a channel and return it.""" + log.trace(f"Getting the channel {channel_id}.") + + channel = bot.get_channel(channel_id) + if not channel: + log.debug(f"Channel {channel_id} is not in cache; fetching from API.") + channel = await bot.fetch_channel(channel_id) + + log.trace(f"Channel #{channel} ({channel_id}) retrieved.") + return channel diff --git a/botcore/utils/extensions.py b/botcore/utils/extensions.py new file mode 100644 index 00000000..c8f200ad --- /dev/null +++ b/botcore/utils/extensions.py @@ -0,0 +1,52 @@ +"""Utilities for loading discord extensions.""" + +import importlib +import inspect +import pkgutil +import types +from typing import NoReturn + + +def unqualify(name: str) -> str: + """ + Return an unqualified name given a qualified module/package `name`. + + Args: + name: The module name to unqualify. + + Returns: + The unqualified module name. + """ + return name.rsplit(".", maxsplit=1)[-1] + + +def walk_extensions(module: types.ModuleType) -> frozenset[str]: + """ + Yield extension names from the bot.exts subpackage. + + Args: + module (types.ModuleType): The module to look for extensions in. + + Returns: + A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. + """ + + def on_error(name: str) -> NoReturn: + raise ImportError(name=name) # pragma: no cover + + modules = set() + + for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): + if unqualify(module_info.name).startswith("_"): + # Ignore module/package names starting with an underscore. + continue + + if module_info.ispkg: + imported = importlib.import_module(module_info.name) + if not inspect.isfunction(getattr(imported, "setup", None)): + # If it lacks a setup function, it's not an extension. 
+ continue + + modules.add(module_info.name) + + return frozenset(modules) diff --git a/botcore/utils/loggers.py b/botcore/utils/loggers.py new file mode 100644 index 00000000..ac1db920 --- /dev/null +++ b/botcore/utils/loggers.py @@ -0,0 +1,45 @@ +"""Custom logging class.""" + +import logging +import typing + +if typing.TYPE_CHECKING: + LoggerClass = logging.Logger +else: + LoggerClass = logging.getLoggerClass() + +TRACE_LEVEL = 5 + + +class CustomLogger(LoggerClass): + """Custom implementation of the `Logger` class with an added `trace` method.""" + + def trace(self, msg: str, *args, **kwargs) -> None: + """ + Log 'msg % args' with severity 'TRACE'. + + To pass exception information, use the keyword argument exc_info with a true value: + + .. code-block:: py + + logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) + + Args: + msg: The message to be logged. + args, kwargs: Passed to the base log function as is. + """ + if self.isEnabledFor(TRACE_LEVEL): + self.log(TRACE_LEVEL, msg, *args, **kwargs) + + +def get_logger(name: typing.Optional[str] = None) -> CustomLogger: + """ + Utility to make mypy recognise that logger is of type `CustomLogger`. + + Args: + name: The name given to the logger. + + Returns: + An instance of the `CustomLogger` class. + """ + return typing.cast(CustomLogger, logging.getLogger(name)) diff --git a/botcore/utils/members.py b/botcore/utils/members.py new file mode 100644 index 00000000..abe7e5e1 --- /dev/null +++ b/botcore/utils/members.py @@ -0,0 +1,48 @@ +import typing + +import discord + +from botcore.utils import loggers + +log = loggers.get_logger(__name__) + + +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]: + """ + Attempt to get a member from cache; on failure fetch from the API. + + Return `None` to indicate the member could not be found. + """ + if member := guild.get_member(member_id): + log.trace(f"{member} retrieved from cache.") + else: + try: + member = await guild.fetch_member(member_id) + except discord.errors.NotFound: + log.trace(f"Failed to fetch {member_id} from API.") + return None + log.trace(f"{member} fetched from API.") + return member + + +async def handle_role_change( + member: discord.Member, + coro: typing.Callable[..., typing.Coroutine], + role: discord.Role +) -> None: + """ + Change `member`'s cooldown role via awaiting `coro` and handle errors. + + `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. + """ + try: + await coro(role) + except discord.NotFound: + log.error(f"Failed to change role for {member} ({member.id}): member not found") + except discord.Forbidden: + log.error( + f"Forbidden to change role for {member} ({member.id}); " + f"possibly due to role hierarchy" + ) + except discord.HTTPException as e: + log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/botcore/utils/regex.py b/botcore/utils/regex.py new file mode 100644 index 00000000..036a5113 --- /dev/null +++ b/botcore/utils/regex.py @@ -0,0 +1,48 @@ +"""Common regular expressions.""" + +import re + +DISCORD_INVITE = re.compile( + r"(discord([.,]|dot)gg|" # Could be discord.gg/ + r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/ + r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/ + r"discord([.,]|dot)me|" # or discord.me + r"discord([.,]|dot)li|" # or discord.li + r"discord([.,]|dot)io|" # or discord.io. 
+ r"((?[a-zA-Z0-9\-]+)", # the invite code itself + flags=re.IGNORECASE +) +""" +Regex for discord server invites. + +:meta hide-value: +""" + +FORMATTED_CODE_REGEX = re.compile( + r"(?P(?P```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block + r"(?(block)(?:(?P[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline) + r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all code inside the markup + r"\s*" # any more whitespace before the end of the code markup + r"(?P=delim)", # match the exact same delimiter from the start again + re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive +) +""" +Regex for formatted code, using Discord's code blocks. + +:meta hide-value: +""" + +RAW_CODE_REGEX = re.compile( + r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code + r"(?P.*?)" # extract all the rest as code + r"\s*$", # any trailing whitespace until the end of the string + re.DOTALL # "." also matches newlines +) +""" +Regex for raw code, *not* using Discord's code blocks. + +:meta hide-value: +""" diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py new file mode 100644 index 00000000..947df0d9 --- /dev/null +++ b/botcore/utils/scheduling.py @@ -0,0 +1,246 @@ +"""Generic python scheduler.""" + +import asyncio +import contextlib +import inspect +import typing +from datetime import datetime +from functools import partial + +from botcore.utils import loggers + + +class Scheduler: + """ + Schedule the execution of coroutines and keep track of them. + + When instantiating a Scheduler, a name must be provided. This name is used to distinguish the + instance's log messages from other instances. Using the name of the class or module containing + the instance is suggested. + + Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` + or `schedule_later`. A unique ID is required to be given in order to keep track of the + resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing + the same ID used to schedule it. The `in` operator is supported for checking if a task with a + given ID is currently scheduled. + + Any exception raised in a scheduled task is logged when the task is done. + """ + + def __init__(self, name: str): + """ + Initialize a new Scheduler instance. + + Args: + name: The name of the scheduler. Used in logging, and namespacing. + """ + self.name = name + + self._log = loggers.get_logger(f"{__name__}.{name}") + self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {} + + def __contains__(self, task_id: typing.Hashable) -> bool: + """ + Return True if a task with the given `task_id` is currently scheduled. + + Args: + task_id: The task to look for. + + Returns: + True if the task was found. + """ + return task_id in self._scheduled_tasks + + def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: + """ + Schedule the execution of a `coroutine`. + + If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + task_id: A unique ID to create the task with. + coroutine: The function to be called. 
+ """ + self._log.trace(f"Scheduling task #{task_id}...") + + msg = f"Cannot schedule an already started coroutine for #{task_id}" + assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg + + if task_id in self._scheduled_tasks: + self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") + coroutine.close() + return + + task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") + task.add_done_callback(partial(self._task_done_callback, task_id)) + + self._scheduled_tasks[task_id] = task + self._log.debug(f"Scheduled task #{task_id} {id(task)}.") + + def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: + """ + Schedule `coroutine` to be executed at the given `time`. + + If `time` is timezone aware, then use that timezone to calculate now() when subtracting. + If `time` is naïve, then use UTC. + + If `time` is in the past, schedule `coroutine` immediately. + + If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + time: The time to start the task. + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() + delay = (time - now_datetime).total_seconds() + if delay > 0: + coroutine = self._await_later(delay, task_id, coroutine) + + self.schedule(task_id, coroutine) + + def schedule_later( + self, + delay: typing.Union[int, float], + task_id: typing.Hashable, + coroutine: typing.Coroutine + ) -> None: + """ + Schedule `coroutine` to be executed after `delay` seconds. + + If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This + prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + + Args: + delay: How long to wait before starting the task. + task_id: A unique ID to create the task with. + coroutine: The function to be called. + """ + self.schedule(task_id, self._await_later(delay, task_id, coroutine)) + + def cancel(self, task_id: typing.Hashable) -> None: + """ + Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist. + + Args: + task_id: The task's unique ID. + """ + self._log.trace(f"Cancelling task #{task_id}...") + + try: + task = self._scheduled_tasks.pop(task_id) + except KeyError: + self._log.warning(f"Failed to unschedule {task_id} (no task found).") + else: + task.cancel() + + self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") + + def cancel_all(self) -> None: + """Unschedule all known tasks.""" + self._log.debug("Unscheduling all tasks") + + for task_id in self._scheduled_tasks.copy(): + self.cancel(task_id) + + async def _await_later( + self, + delay: typing.Union[int, float], + task_id: typing.Hashable, + coroutine: typing.Coroutine + ) -> None: + """Await `coroutine` after `delay` seconds.""" + try: + self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") + await asyncio.sleep(delay) + + # Use asyncio.shield to prevent the coroutine from cancelling itself. + self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") + await asyncio.shield(coroutine) + finally: + # Close it to prevent unawaited coroutine warnings, + # which would happen if the task was cancelled during the sleep. + # Only close it if it's not been awaited yet. 
This check is important because the + # coroutine may cancel this task, which would also trigger the finally block. + state = inspect.getcoroutinestate(coroutine) + if state == "CORO_CREATED": + self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") + coroutine.close() + else: + self._log.debug(f"Finally block reached for #{task_id}; {state=}") + + def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None: + """ + Delete the task and raise its exception if one exists. + + If `done_task` and the task associated with `task_id` are different, then the latter + will not be deleted. In this case, a new task was likely rescheduled with the same ID. + """ + self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") + + scheduled_task = self._scheduled_tasks.get(task_id) + + if scheduled_task and done_task is scheduled_task: + # A task for the ID exists and is the same as the done task. + # Since this is the done callback, the task is already done so no need to cancel it. + self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") + del self._scheduled_tasks[task_id] + elif scheduled_task: + # A new task was likely rescheduled with the same ID. + self._log.debug( + f"The scheduled task #{task_id} {id(scheduled_task)} " + f"and the done task {id(done_task)} differ." + ) + elif not done_task.cancelled(): + self._log.warning( + f"Task #{task_id} not found while handling task {id(done_task)}! " + f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." + ) + + with contextlib.suppress(asyncio.CancelledError): + exception = done_task.exception() + # Log the exception if one exists. + if exception: + self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) + + +def create_task( + coro: typing.Awaitable, + *, + suppressed_exceptions: tuple[typing.Type[Exception]] = (), + event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, + **kwargs, +) -> asyncio.Task: + """ + Wrapper for creating asyncio `Tasks` which logs exceptions raised in the task. + + If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used. + + Args: + coro: The function to call. + suppressed_exceptions: Exceptions to be handled by the task. + event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from. + kwargs: Passed to :py:func:`asyncio.create_task`. + + Returns: + asyncio.Task: The wrapped task. + """ + if event_loop is not None: + task = event_loop.create_task(coro, **kwargs) + else: + task = asyncio.create_task(coro, **kwargs) + task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) + return task + + +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None: + """Retrieve and log the exception raised in `task` if one exists.""" + with contextlib.suppress(asyncio.CancelledError): + exception = task.exception() + # Log the exception if one exists. + if exception and not isinstance(exception, suppressed_exceptions): + log = loggers.get_logger(__name__) + log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) -- cgit v1.2.3
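The AsyncCache decorator ported above caches a coroutine's result keyed on its positional arguments from arg_offset onward. A minimal usage sketch; the API client and fetch_user_profile coroutine are hypothetical, not part of this patch:

from botcore.utils.caching import AsyncCache

profile_cache = AsyncCache(max_size=64)


@profile_cache(arg_offset=1)  # skip the client argument so only user_id forms the cache key
async def fetch_user_profile(client, user_id: int) -> dict:
    # The first call per user_id hits the API; later calls return the cached value.
    return await client.get_profile(user_id)  # hypothetical client method


# profile_cache.clear() discards every cached entry, e.g. after a bulk data refresh.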
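get_or_fetch_channel falls back to an API fetch when the channel is not cached, and is_in_category is a safe category check for channel types that may lack category_id. A sketch assuming a commands.Bot instance, a text channel, and a placeholder category ID:

from botcore.utils import channel as channel_utils

HELP_CATEGORY_ID = 123456789012345678  # placeholder


async def announce(bot, channel_id: int) -> None:
    target = await channel_utils.get_or_fetch_channel(bot, channel_id)
    if channel_utils.is_in_category(target, HELP_CATEGORY_ID):
        await target.send("This help channel is now available.")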
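walk_extensions collects the dotted names of every module in a package that defines a setup function, ready to hand to load_extension. A sketch; the myproject.exts package and bot object are assumptions:

import myproject.exts

from botcore.utils.extensions import walk_extensions


def load_all_extensions(bot) -> None:
    for extension in walk_extensions(myproject.exts):
        # On discord library versions where load_extension is a coroutine, await it instead.
        bot.load_extension(extension)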
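get_logger is only a typing.cast around logging.getLogger, so for trace() to exist at runtime the CustomLogger class and the TRACE level still need to be registered once at startup; the modules in this patch do not do that themselves. A sketch of how a consuming bot might wire it up, run before any module creates its loggers at import time:

import logging

from botcore.utils import loggers

logging.addLevelName(loggers.TRACE_LEVEL, "TRACE")
logging.setLoggerClass(loggers.CustomLogger)
logging.basicConfig(level=loggers.TRACE_LEVEL)

log = loggers.get_logger(__name__)
log.trace("Gateway handshake took %.2f seconds", 1.23)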
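get_or_fetch_member mirrors the channel helper but returns None when the member cannot be found, and handle_role_change wraps add_roles/remove_roles so the common Discord API errors are logged rather than raised. A sketch with placeholder objects:

from botcore.utils import members


async def apply_cooldown(guild, member_id: int, cooldown_role) -> None:
    member = await members.get_or_fetch_member(guild, member_id)
    if member is None:
        return  # member left the guild; nothing to do
    await members.handle_role_change(member, member.add_roles, cooldown_role)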
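FORMATTED_CODE_REGEX is typically used to pull source code out of a message that may or may not use a fenced code block. The group names below are read off the pattern's backreferences and inline comments (block and delim are backreferenced; lang and code are inferred from the comments), so treat them as assumptions:

from botcore.utils.regex import FORMATTED_CODE_REGEX

message = "Try this:\n```py\nprint('hello')\n```"
match = FORMATTED_CODE_REGEX.search(message)
if match:
    source = match.group("code")    # assumed group name: "print('hello')"
    language = match.group("lang")  # assumed group name: "py", or None for inline code
    is_block = match.group("block") is not None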
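The Scheduler keeps one asyncio.Task per unique task_id and closes a coroutine instead of scheduling it twice. A hypothetical reminder flow; schedule() calls asyncio.create_task, so these functions must be called while an event loop is running:

from botcore.utils.scheduling import Scheduler

scheduler = Scheduler("reminders")


async def send_reminder(reminder_id: int) -> None:
    ...  # look the reminder up and deliver it


def schedule_reminder(reminder_id: int, delay_seconds: float) -> None:
    scheduler.schedule_later(delay_seconds, reminder_id, send_reminder(reminder_id))


def cancel_reminder(reminder_id: int) -> None:
    if reminder_id in scheduler:  # __contains__ checks for a scheduled task with this ID
        scheduler.cancel(reminder_id)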
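create_task wraps asyncio.create_task and adds a done callback that logs any exception the task raised, except types listed in suppressed_exceptions. A sketch; refresh_stats and on_ready are placeholders:

import asyncio

from botcore.utils.scheduling import create_task


async def refresh_stats() -> None:
    ...


async def on_ready() -> None:
    # Fire-and-forget: failures other than TimeoutError are logged by the done callback.
    create_task(refresh_stats(), suppressed_exceptions=(asyncio.TimeoutError,))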