summaryrefslogtreecommitdiffstats
path: root/botcore/utils
diff options
context:
space:
mode:
Diffstat (limited to 'botcore/utils')
-rw-r--r--botcore/utils/__init__.py15
-rw-r--r--botcore/utils/caching.py65
-rw-r--r--botcore/utils/channel.py26
-rw-r--r--botcore/utils/extensions.py52
-rw-r--r--botcore/utils/loggers.py45
-rw-r--r--botcore/utils/members.py48
-rw-r--r--botcore/utils/regex.py48
-rw-r--r--botcore/utils/scheduling.py246
8 files changed, 545 insertions, 0 deletions
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py
new file mode 100644
index 00000000..554e8ad1
--- /dev/null
+++ b/botcore/utils/__init__.py
@@ -0,0 +1,15 @@
+"""Useful utilities and tools for discord bot development."""
+
+from botcore.utils import (caching, channel, extensions, loggers, members, regex, scheduling)
+
+__all__ = [
+ caching,
+ channel,
+ extensions,
+ loggers,
+ members,
+ regex,
+ scheduling,
+]
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/botcore/utils/caching.py b/botcore/utils/caching.py
new file mode 100644
index 00000000..ea71ed1d
--- /dev/null
+++ b/botcore/utils/caching.py
@@ -0,0 +1,65 @@
+"""Utilities related to custom caches."""
+
+import functools
+import typing
+from collections import OrderedDict
+
+
+class AsyncCache:
+ """
+ LRU cache implementation for coroutines.
+
+ Once the cache exceeds the maximum size, keys are deleted in FIFO order.
+
+ An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key.
+ """
+
+ def __init__(self, max_size: int = 128):
+ """
+ Initialise a new AsyncCache instance.
+
+ Args:
+ max_size: How many items to store in the cache.
+ """
+ self._cache = OrderedDict()
+ self._max_size = max_size
+
+ def __call__(self, arg_offset: int = 0) -> typing.Callable:
+ """
+ Decorator for async cache.
+
+ Args:
+ arg_offset: The offset for the position of the key argument.
+
+ Returns:
+ A decorator to wrap the target function.
+ """
+
+ def decorator(function: typing.Callable) -> typing.Callable:
+ """
+ Define the async cache decorator.
+
+ Args:
+ function: The function to wrap.
+
+ Returns:
+ The wrapped function.
+ """
+
+ @functools.wraps(function)
+ async def wrapper(*args) -> typing.Any:
+ """Decorator wrapper for the caching logic."""
+ key = args[arg_offset:]
+
+ if key not in self._cache:
+ if len(self._cache) > self._max_size:
+ self._cache.popitem(last=False)
+
+ self._cache[key] = await function(*args)
+ return self._cache[key]
+ return wrapper
+ return decorator
+
+ def clear(self) -> None:
+ """Clear cache instance."""
+ self._cache.clear()
diff --git a/botcore/utils/channel.py b/botcore/utils/channel.py
new file mode 100644
index 00000000..7e0fc387
--- /dev/null
+++ b/botcore/utils/channel.py
@@ -0,0 +1,26 @@
+"""Utilities for interacting with discord channels."""
+
+import discord
+from discord.ext.commands import Bot
+
+from botcore.utils import loggers
+
+log = loggers.get_logger(__name__)
+
+
+def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
+ """Return True if `channel` is within a category with `category_id`."""
+ return getattr(channel, "category_id", None) == category_id
+
+
+async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel:
+ """Attempt to get or fetch a channel and return it."""
+ log.trace(f"Getting the channel {channel_id}.")
+
+ channel = bot.get_channel(channel_id)
+ if not channel:
+ log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
+ channel = await bot.fetch_channel(channel_id)
+
+ log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
+ return channel
diff --git a/botcore/utils/extensions.py b/botcore/utils/extensions.py
new file mode 100644
index 00000000..c8f200ad
--- /dev/null
+++ b/botcore/utils/extensions.py
@@ -0,0 +1,52 @@
+"""Utilities for loading discord extensions."""
+
+import importlib
+import inspect
+import pkgutil
+import types
+from typing import NoReturn
+
+
+def unqualify(name: str) -> str:
+ """
+ Return an unqualified name given a qualified module/package `name`.
+
+ Args:
+ name: The module name to unqualify.
+
+ Returns:
+ The unqualified module name.
+ """
+ return name.rsplit(".", maxsplit=1)[-1]
+
+
+def walk_extensions(module: types.ModuleType) -> frozenset[str]:
+ """
+ Yield extension names from the bot.exts subpackage.
+
+ Args:
+ module (types.ModuleType): The module to look for extensions in.
+
+ Returns:
+ A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`.
+ """
+
+ def on_error(name: str) -> NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ modules = set()
+
+ for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error):
+ if unqualify(module_info.name).startswith("_"):
+ # Ignore module/package names starting with an underscore.
+ continue
+
+ if module_info.ispkg:
+ imported = importlib.import_module(module_info.name)
+ if not inspect.isfunction(getattr(imported, "setup", None)):
+ # If it lacks a setup function, it's not an extension.
+ continue
+
+ modules.add(module_info.name)
+
+ return frozenset(modules)
diff --git a/botcore/utils/loggers.py b/botcore/utils/loggers.py
new file mode 100644
index 00000000..ac1db920
--- /dev/null
+++ b/botcore/utils/loggers.py
@@ -0,0 +1,45 @@
+"""Custom logging class."""
+
+import logging
+import typing
+
+if typing.TYPE_CHECKING:
+ LoggerClass = logging.Logger
+else:
+ LoggerClass = logging.getLoggerClass()
+
+TRACE_LEVEL = 5
+
+
+class CustomLogger(LoggerClass):
+ """Custom implementation of the `Logger` class with an added `trace` method."""
+
+ def trace(self, msg: str, *args, **kwargs) -> None:
+ """
+ Log 'msg % args' with severity 'TRACE'.
+
+ To pass exception information, use the keyword argument exc_info with a true value:
+
+ .. code-block:: py
+
+ logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
+
+ Args:
+ msg: The message to be logged.
+ args, kwargs: Passed to the base log function as is.
+ """
+ if self.isEnabledFor(TRACE_LEVEL):
+ self.log(TRACE_LEVEL, msg, *args, **kwargs)
+
+
+def get_logger(name: typing.Optional[str] = None) -> CustomLogger:
+ """
+ Utility to make mypy recognise that logger is of type `CustomLogger`.
+
+ Args:
+ name: The name given to the logger.
+
+ Returns:
+ An instance of the `CustomLogger` class.
+ """
+ return typing.cast(CustomLogger, logging.getLogger(name))
diff --git a/botcore/utils/members.py b/botcore/utils/members.py
new file mode 100644
index 00000000..abe7e5e1
--- /dev/null
+++ b/botcore/utils/members.py
@@ -0,0 +1,48 @@
+import typing
+
+import discord
+
+from botcore.utils import loggers
+
+log = loggers.get_logger(__name__)
+
+
+async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]:
+ """
+ Attempt to get a member from cache; on failure fetch from the API.
+
+ Return `None` to indicate the member could not be found.
+ """
+ if member := guild.get_member(member_id):
+ log.trace(f"{member} retrieved from cache.")
+ else:
+ try:
+ member = await guild.fetch_member(member_id)
+ except discord.errors.NotFound:
+ log.trace(f"Failed to fetch {member_id} from API.")
+ return None
+ log.trace(f"{member} fetched from API.")
+ return member
+
+
+async def handle_role_change(
+ member: discord.Member,
+ coro: typing.Callable[..., typing.Coroutine],
+ role: discord.Role
+) -> None:
+ """
+ Change `member`'s cooldown role via awaiting `coro` and handle errors.
+
+ `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`.
+ """
+ try:
+ await coro(role)
+ except discord.NotFound:
+ log.error(f"Failed to change role for {member} ({member.id}): member not found")
+ except discord.Forbidden:
+ log.error(
+ f"Forbidden to change role for {member} ({member.id}); "
+ f"possibly due to role hierarchy"
+ )
+ except discord.HTTPException as e:
+ log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}")
diff --git a/botcore/utils/regex.py b/botcore/utils/regex.py
new file mode 100644
index 00000000..036a5113
--- /dev/null
+++ b/botcore/utils/regex.py
@@ -0,0 +1,48 @@
+"""Common regular expressions."""
+
+import re
+
+DISCORD_INVITE = re.compile(
+ r"(discord([.,]|dot)gg|" # Could be discord.gg/
+ r"discord([.,]|dot)com(/|slash)invite|" # or discord.com/invite/
+ r"discordapp([.,]|dot)com(/|slash)invite|" # or discordapp.com/invite/
+ r"discord([.,]|dot)me|" # or discord.me
+ r"discord([.,]|dot)li|" # or discord.li
+ r"discord([.,]|dot)io|" # or discord.io.
+ r"((?<!\w)([.,]|dot))gg" # or .gg/
+ r")([/]|slash)" # / or 'slash'
+ r"(?P<invite>[a-zA-Z0-9\-]+)", # the invite code itself
+ flags=re.IGNORECASE
+)
+"""
+Regex for discord server invites.
+
+:meta hide-value:
+"""
+
+FORMATTED_CODE_REGEX = re.compile(
+ r"(?P<delim>(?P<block>```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block
+ r"(?(block)(?:(?P<lang>[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline)
+ r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
+ r"(?P<code>.*?)" # extract all code inside the markup
+ r"\s*" # any more whitespace before the end of the code markup
+ r"(?P=delim)", # match the exact same delimiter from the start again
+ re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive
+)
+"""
+Regex for formatted code, using Discord's code blocks.
+
+:meta hide-value:
+"""
+
+RAW_CODE_REGEX = re.compile(
+ r"^(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
+ r"(?P<code>.*?)" # extract all the rest as code
+ r"\s*$", # any trailing whitespace until the end of the string
+ re.DOTALL # "." also matches newlines
+)
+"""
+Regex for raw code, *not* using Discord's code blocks.
+
+:meta hide-value:
+"""
diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py
new file mode 100644
index 00000000..947df0d9
--- /dev/null
+++ b/botcore/utils/scheduling.py
@@ -0,0 +1,246 @@
+"""Generic python scheduler."""
+
+import asyncio
+import contextlib
+import inspect
+import typing
+from datetime import datetime
+from functools import partial
+
+from botcore.utils import loggers
+
+
+class Scheduler:
+ """
+ Schedule the execution of coroutines and keep track of them.
+
+ When instantiating a Scheduler, a name must be provided. This name is used to distinguish the
+ instance's log messages from other instances. Using the name of the class or module containing
+ the instance is suggested.
+
+ Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at`
+ or `schedule_later`. A unique ID is required to be given in order to keep track of the
+ resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing
+ the same ID used to schedule it. The `in` operator is supported for checking if a task with a
+ given ID is currently scheduled.
+
+ Any exception raised in a scheduled task is logged when the task is done.
+ """
+
+ def __init__(self, name: str):
+ """
+ Initialize a new Scheduler instance.
+
+ Args:
+ name: The name of the scheduler. Used in logging, and namespacing.
+ """
+ self.name = name
+
+ self._log = loggers.get_logger(f"{__name__}.{name}")
+ self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {}
+
+ def __contains__(self, task_id: typing.Hashable) -> bool:
+ """
+ Return True if a task with the given `task_id` is currently scheduled.
+
+ Args:
+ task_id: The task to look for.
+
+ Returns:
+ True if the task was found.
+ """
+ return task_id in self._scheduled_tasks
+
+ def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ """
+ Schedule the execution of a `coroutine`.
+
+ If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self._log.trace(f"Scheduling task #{task_id}...")
+
+ msg = f"Cannot schedule an already started coroutine for #{task_id}"
+ assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg
+
+ if task_id in self._scheduled_tasks:
+ self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.")
+ coroutine.close()
+ return
+
+ task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}")
+ task.add_done_callback(partial(self._task_done_callback, task_id))
+
+ self._scheduled_tasks[task_id] = task
+ self._log.debug(f"Scheduled task #{task_id} {id(task)}.")
+
+ def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ """
+ Schedule `coroutine` to be executed at the given `time`.
+
+ If `time` is timezone aware, then use that timezone to calculate now() when subtracting.
+ If `time` is naïve, then use UTC.
+
+ If `time` is in the past, schedule `coroutine` immediately.
+
+ If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ time: The time to start the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow()
+ delay = (time - now_datetime).total_seconds()
+ if delay > 0:
+ coroutine = self._await_later(delay, task_id, coroutine)
+
+ self.schedule(task_id, coroutine)
+
+ def schedule_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: typing.Hashable,
+ coroutine: typing.Coroutine
+ ) -> None:
+ """
+ Schedule `coroutine` to be executed after `delay` seconds.
+
+ If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ delay: How long to wait before starting the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self.schedule(task_id, self._await_later(delay, task_id, coroutine))
+
+ def cancel(self, task_id: typing.Hashable) -> None:
+ """
+ Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.
+
+ Args:
+ task_id: The task's unique ID.
+ """
+ self._log.trace(f"Cancelling task #{task_id}...")
+
+ try:
+ task = self._scheduled_tasks.pop(task_id)
+ except KeyError:
+ self._log.warning(f"Failed to unschedule {task_id} (no task found).")
+ else:
+ task.cancel()
+
+ self._log.debug(f"Unscheduled task #{task_id} {id(task)}.")
+
+ def cancel_all(self) -> None:
+ """Unschedule all known tasks."""
+ self._log.debug("Unscheduling all tasks")
+
+ for task_id in self._scheduled_tasks.copy():
+ self.cancel(task_id)
+
+ async def _await_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: typing.Hashable,
+ coroutine: typing.Coroutine
+ ) -> None:
+ """Await `coroutine` after `delay` seconds."""
+ try:
+ self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.")
+ await asyncio.sleep(delay)
+
+ # Use asyncio.shield to prevent the coroutine from cancelling itself.
+ self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.")
+ await asyncio.shield(coroutine)
+ finally:
+ # Close it to prevent unawaited coroutine warnings,
+ # which would happen if the task was cancelled during the sleep.
+ # Only close it if it's not been awaited yet. This check is important because the
+ # coroutine may cancel this task, which would also trigger the finally block.
+ state = inspect.getcoroutinestate(coroutine)
+ if state == "CORO_CREATED":
+ self._log.debug(f"Explicitly closing the coroutine for #{task_id}.")
+ coroutine.close()
+ else:
+ self._log.debug(f"Finally block reached for #{task_id}; {state=}")
+
+ def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None:
+ """
+ Delete the task and raise its exception if one exists.
+
+ If `done_task` and the task associated with `task_id` are different, then the latter
+ will not be deleted. In this case, a new task was likely rescheduled with the same ID.
+ """
+ self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.")
+
+ scheduled_task = self._scheduled_tasks.get(task_id)
+
+ if scheduled_task and done_task is scheduled_task:
+ # A task for the ID exists and is the same as the done task.
+ # Since this is the done callback, the task is already done so no need to cancel it.
+ self._log.trace(f"Deleting task #{task_id} {id(done_task)}.")
+ del self._scheduled_tasks[task_id]
+ elif scheduled_task:
+ # A new task was likely rescheduled with the same ID.
+ self._log.debug(
+ f"The scheduled task #{task_id} {id(scheduled_task)} "
+ f"and the done task {id(done_task)} differ."
+ )
+ elif not done_task.cancelled():
+ self._log.warning(
+ f"Task #{task_id} not found while handling task {id(done_task)}! "
+ f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)."
+ )
+
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = done_task.exception()
+ # Log the exception if one exists.
+ if exception:
+ self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
+
+
+def create_task(
+ coro: typing.Awaitable,
+ *,
+ suppressed_exceptions: tuple[typing.Type[Exception]] = (),
+ event_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
+ **kwargs,
+) -> asyncio.Task:
+ """
+ Wrapper for creating asyncio `Tasks` which logs exceptions raised in the task.
+
+ If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used.
+
+ Args:
+ coro: The function to call.
+ suppressed_exceptions: Exceptions to be handled by the task.
+ event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from.
+ kwargs: Passed to :py:func:`asyncio.create_task`.
+
+ Returns:
+ asyncio.Task: The wrapped task.
+ """
+ if event_loop is not None:
+ task = event_loop.create_task(coro, **kwargs)
+ else:
+ task = asyncio.create_task(coro, **kwargs)
+ task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
+ return task
+
+
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None:
+ """Retrieve and log the exception raised in `task` if one exists."""
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = task.exception()
+ # Log the exception if one exists.
+ if exception and not isinstance(exception, suppressed_exceptions):
+ log = loggers.get_logger(__name__)
+ log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)