summaryrefslogtreecommitdiffstats
path: root/botcore
diff options
context:
space:
mode:
authorGravatar Hassan Abouelela <[email protected]>2022-02-24 21:58:42 +0400
committerGravatar GitHub <[email protected]>2022-02-24 21:58:42 +0400
commit554919b6314814320f35431a6cfb32ca81b09079 (patch)
treea330a6f7b8a4081be138c8f5acc3be3ca8ced99a /botcore
parentMerge pull request #23 from python-discord/bump-deps (diff)
parentUpdate GHA Docs Build To Match Pyproject (diff)
Merge pull request #29 from python-discord/port-utilities
Port utilities
Diffstat (limited to 'botcore')
-rw-r--r--botcore/__init__.py9
-rw-r--r--botcore/exts/__init__.py4
-rw-r--r--botcore/utils/__init__.py15
-rw-r--r--botcore/utils/caching.py65
-rw-r--r--botcore/utils/channel.py54
-rw-r--r--botcore/utils/extensions.py52
-rw-r--r--botcore/utils/logging.py45
-rw-r--r--botcore/utils/members.py56
-rw-r--r--botcore/utils/regex.py (renamed from botcore/regex.py)0
-rw-r--r--botcore/utils/scheduling.py248
10 files changed, 544 insertions, 4 deletions
diff --git a/botcore/__init__.py b/botcore/__init__.py
index c582d0df..d910f393 100644
--- a/botcore/__init__.py
+++ b/botcore/__init__.py
@@ -1,9 +1,10 @@
-from botcore import (
- regex,
-)
+"""Useful utilities and tools for discord bot development."""
+
+from botcore import exts, utils
__all__ = [
- regex,
+ exts,
+ utils,
]
__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/botcore/exts/__init__.py b/botcore/exts/__init__.py
new file mode 100644
index 00000000..029178a9
--- /dev/null
+++ b/botcore/exts/__init__.py
@@ -0,0 +1,4 @@
+"""Reusable discord cogs."""
+__all__ = []
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/botcore/utils/__init__.py b/botcore/utils/__init__.py
new file mode 100644
index 00000000..71354334
--- /dev/null
+++ b/botcore/utils/__init__.py
@@ -0,0 +1,15 @@
+"""Useful utilities and tools for discord bot development."""
+
+from botcore.utils import (caching, channel, extensions, logging, members, regex, scheduling)
+
+__all__ = [
+ caching,
+ channel,
+ extensions,
+ logging,
+ members,
+ regex,
+ scheduling,
+]
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/botcore/utils/caching.py b/botcore/utils/caching.py
new file mode 100644
index 00000000..ac34bb9b
--- /dev/null
+++ b/botcore/utils/caching.py
@@ -0,0 +1,65 @@
+"""Utilities related to custom caches."""
+
+import functools
+import typing
+from collections import OrderedDict
+
+
+class AsyncCache:
+ """
+ LRU cache implementation for coroutines.
+
+ Once the cache exceeds the maximum size, keys are deleted in FIFO order.
+
+ An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key.
+ """
+
+ def __init__(self, max_size: int = 128):
+ """
+ Initialise a new :obj:`AsyncCache` instance.
+
+ Args:
+ max_size: How many items to store in the cache.
+ """
+ self._cache = OrderedDict()
+ self._max_size = max_size
+
+ def __call__(self, arg_offset: int = 0) -> typing.Callable:
+ """
+ Decorator for async cache.
+
+ Args:
+ arg_offset: The offset for the position of the key argument.
+
+ Returns:
+ A decorator to wrap the target function.
+ """
+
+ def decorator(function: typing.Callable) -> typing.Callable:
+ """
+ Define the async cache decorator.
+
+ Args:
+ function: The function to wrap.
+
+ Returns:
+ The wrapped function.
+ """
+
+ @functools.wraps(function)
+ async def wrapper(*args) -> typing.Any:
+ """Decorator wrapper for the caching logic."""
+ key = args[arg_offset:]
+
+ if key not in self._cache:
+ if len(self._cache) > self._max_size:
+ self._cache.popitem(last=False)
+
+ self._cache[key] = await function(*args)
+ return self._cache[key]
+ return wrapper
+ return decorator
+
+ def clear(self) -> None:
+ """Clear cache instance."""
+ self._cache.clear()
diff --git a/botcore/utils/channel.py b/botcore/utils/channel.py
new file mode 100644
index 00000000..17e70a2a
--- /dev/null
+++ b/botcore/utils/channel.py
@@ -0,0 +1,54 @@
+"""Useful helper functions for interacting with various discord.py channel objects."""
+
+import discord
+from discord.ext.commands import Bot
+
+from botcore.utils import logging
+
+log = logging.get_logger(__name__)
+
+
+def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
+ """
+ Return whether the given ``channel`` in the the category with the id ``category_id``.
+
+ Args:
+ channel: The channel to check.
+ category_id: The category to check for.
+
+ Returns:
+ A bool depending on whether the channel is in the category.
+ """
+ return getattr(channel, "category_id", None) == category_id
+
+
+async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel:
+ """
+ Attempt to get or fetch the given ``channel_id`` from the bots cache, and return it.
+
+ Args:
+ bot: The :obj:`discord.ext.commands.Bot` instance to use for getting/fetching.
+ channel_id: The channel to get/fetch.
+
+ Raises:
+ :exc:`discord.InvalidData`
+ An unknown channel type was received from Discord.
+ :exc:`discord.HTTPException`
+ Retrieving the channel failed.
+ :exc:`discord.NotFound`
+ Invalid Channel ID.
+ :exc:`discord.Forbidden`
+ You do not have permission to fetch this channel.
+
+ Returns:
+ The channel from the ID.
+ """
+ log.trace(f"Getting the channel {channel_id}.")
+
+ channel = bot.get_channel(channel_id)
+ if not channel:
+ log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
+ channel = await bot.fetch_channel(channel_id)
+
+ log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
+ return channel
diff --git a/botcore/utils/extensions.py b/botcore/utils/extensions.py
new file mode 100644
index 00000000..3f8d6e6d
--- /dev/null
+++ b/botcore/utils/extensions.py
@@ -0,0 +1,52 @@
+"""Utilities for loading discord extensions."""
+
+import importlib
+import inspect
+import pkgutil
+import types
+from typing import NoReturn
+
+
+def unqualify(name: str) -> str:
+ """
+ Return an unqualified name given a qualified module/package ``name``.
+
+ Args:
+ name: The module name to unqualify.
+
+ Returns:
+ The unqualified module name.
+ """
+ return name.rsplit(".", maxsplit=1)[-1]
+
+
+def walk_extensions(module: types.ModuleType) -> frozenset[str]:
+ """
+ Yield extension names from the given module.
+
+ Args:
+ module (types.ModuleType): The module to look for extensions in.
+
+ Returns:
+ A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`.
+ """
+
+ def on_error(name: str) -> NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ modules = set()
+
+ for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error):
+ if unqualify(module_info.name).startswith("_"):
+ # Ignore module/package names starting with an underscore.
+ continue
+
+ if module_info.ispkg:
+ imported = importlib.import_module(module_info.name)
+ if not inspect.isfunction(getattr(imported, "setup", None)):
+ # If it lacks a setup function, it's not an extension.
+ continue
+
+ modules.add(module_info.name)
+
+ return frozenset(modules)
diff --git a/botcore/utils/logging.py b/botcore/utils/logging.py
new file mode 100644
index 00000000..71ce4e66
--- /dev/null
+++ b/botcore/utils/logging.py
@@ -0,0 +1,45 @@
+"""Common logging related functions."""
+
+import logging
+import typing
+
+if typing.TYPE_CHECKING:
+ LoggerClass = logging.Logger
+else:
+ LoggerClass = logging.getLoggerClass()
+
+TRACE_LEVEL = 5
+
+
+class CustomLogger(LoggerClass):
+ """Custom implementation of the :obj:`logging.Logger` class with an added :obj:`trace` method."""
+
+ def trace(self, msg: str, *args, **kwargs) -> None:
+ """
+ Log the given message with the severity ``"TRACE"``.
+
+ To pass exception information, use the keyword argument exc_info with a true value:
+
+ .. code-block:: py
+
+ logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
+
+ Args:
+ msg: The message to be logged.
+ args, kwargs: Passed to the base log function as is.
+ """
+ if self.isEnabledFor(TRACE_LEVEL):
+ self.log(TRACE_LEVEL, msg, *args, **kwargs)
+
+
+def get_logger(name: typing.Optional[str] = None) -> CustomLogger:
+ """
+ Utility to make mypy recognise that logger is of type :obj:`CustomLogger`.
+
+ Args:
+ name: The name given to the logger.
+
+ Returns:
+ An instance of the :obj:`CustomLogger` class.
+ """
+ return typing.cast(CustomLogger, logging.getLogger(name))
diff --git a/botcore/utils/members.py b/botcore/utils/members.py
new file mode 100644
index 00000000..e89b4618
--- /dev/null
+++ b/botcore/utils/members.py
@@ -0,0 +1,56 @@
+"""Useful helper functions for interactin with :obj:`discord.Member` objects."""
+
+import typing
+
+import discord
+
+from botcore.utils import logging
+
+log = logging.get_logger(__name__)
+
+
+async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]:
+ """
+ Attempt to get a member from cache; on failure fetch from the API.
+
+ Returns:
+ The :obj:`discord.Member` or :obj:`None` to indicate the member could not be found.
+ """
+ if member := guild.get_member(member_id):
+ log.trace(f"{member} retrieved from cache.")
+ else:
+ try:
+ member = await guild.fetch_member(member_id)
+ except discord.errors.NotFound:
+ log.trace(f"Failed to fetch {member_id} from API.")
+ return None
+ log.trace(f"{member} fetched from API.")
+ return member
+
+
+async def handle_role_change(
+ member: discord.Member,
+ coro: typing.Callable[..., typing.Coroutine],
+ role: discord.Role
+) -> None:
+ """
+ Await the given ``coro`` with ``member`` as the sole argument.
+
+ Handle errors that we expect to be raised from
+ :obj:`discord.Member.add_roles` and :obj:`discord.Member.remove_roles`.
+
+ Args:
+ member: The member to pass to ``coro``.
+ coro: This is intended to be :obj:`discord.Member.add_roles` or :obj:`discord.Member.remove_roles`.
+ """
+ try:
+ await coro(role)
+ except discord.NotFound:
+ log.error(f"Failed to change role for {member} ({member.id}): member not found")
+ except discord.Forbidden:
+ log.error(
+ f"Forbidden to change role for {member} ({member.id}); "
+ f"possibly due to role hierarchy"
+ )
+ except discord.HTTPException as e:
+ log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}")
diff --git a/botcore/regex.py b/botcore/utils/regex.py
index 036a5113..036a5113 100644
--- a/botcore/regex.py
+++ b/botcore/utils/regex.py
diff --git a/botcore/utils/scheduling.py b/botcore/utils/scheduling.py
new file mode 100644
index 00000000..164f6b10
--- /dev/null
+++ b/botcore/utils/scheduling.py
@@ -0,0 +1,248 @@
+"""Generic python scheduler."""
+
+import asyncio
+import contextlib
+import inspect
+import typing
+from datetime import datetime
+from functools import partial
+
+from botcore.utils import logging
+
+
+class Scheduler:
+ """
+ Schedule the execution of coroutines and keep track of them.
+
+ When instantiating a :obj:`Scheduler`, a name must be provided. This name is used to distinguish the
+ instance's log messages from other instances. Using the name of the class or module containing
+ the instance is suggested.
+
+ Coroutines can be scheduled immediately with :obj:`schedule` or in the future with :obj:`schedule_at`
+ or :obj:`schedule_later`. A unique ID is required to be given in order to keep track of the
+ resulting Tasks. Any scheduled task can be cancelled prematurely using :obj:`cancel` by providing
+ the same ID used to schedule it.
+
+ The ``in`` operator is supported for checking if a task with a given ID is currently scheduled.
+
+ Any exception raised in a scheduled task is logged when the task is done.
+ """
+
+ def __init__(self, name: str):
+ """
+ Initialize a new :obj:`Scheduler` instance.
+
+ Args:
+ name: The name of the :obj:`Scheduler`. Used in logging, and namespacing.
+ """
+ self.name = name
+
+ self._log = logging.get_logger(f"{__name__}.{name}")
+ self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {}
+
+ def __contains__(self, task_id: typing.Hashable) -> bool:
+ """
+ Return :obj:`True` if a task with the given ``task_id`` is currently scheduled.
+
+ Args:
+ task_id: The task to look for.
+
+ Returns:
+ :obj:`True` if the task was found.
+ """
+ return task_id in self._scheduled_tasks
+
+ def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ """
+ Schedule the execution of a ``coroutine``.
+
+ If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self._log.trace(f"Scheduling task #{task_id}...")
+
+ msg = f"Cannot schedule an already started coroutine for #{task_id}"
+ assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg
+
+ if task_id in self._scheduled_tasks:
+ self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.")
+ coroutine.close()
+ return
+
+ task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}")
+ task.add_done_callback(partial(self._task_done_callback, task_id))
+
+ self._scheduled_tasks[task_id] = task
+ self._log.debug(f"Scheduled task #{task_id} {id(task)}.")
+
+ def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ """
+ Schedule ``coroutine`` to be executed at the given ``time``.
+
+ If ``time`` is timezone aware, then use that timezone to calculate now() when subtracting.
+ If ``time`` is naïve, then use UTC.
+
+ If ``time`` is in the past, schedule ``coroutine`` immediately.
+
+ If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ time: The time to start the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow()
+ delay = (time - now_datetime).total_seconds()
+ if delay > 0:
+ coroutine = self._await_later(delay, task_id, coroutine)
+
+ self.schedule(task_id, coroutine)
+
+ def schedule_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: typing.Hashable,
+ coroutine: typing.Coroutine
+ ) -> None:
+ """
+ Schedule ``coroutine`` to be executed after ``delay`` seconds.
+
+ If a task with ``task_id`` already exists, close ``coroutine`` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ delay: How long to wait before starting the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self.schedule(task_id, self._await_later(delay, task_id, coroutine))
+
+ def cancel(self, task_id: typing.Hashable) -> None:
+ """
+ Unschedule the task identified by ``task_id``. Log a warning if the task doesn't exist.
+
+ Args:
+ task_id: The task's unique ID.
+ """
+ self._log.trace(f"Cancelling task #{task_id}...")
+
+ try:
+ task = self._scheduled_tasks.pop(task_id)
+ except KeyError:
+ self._log.warning(f"Failed to unschedule {task_id} (no task found).")
+ else:
+ task.cancel()
+
+ self._log.debug(f"Unscheduled task #{task_id} {id(task)}.")
+
+ def cancel_all(self) -> None:
+ """Unschedule all known tasks."""
+ self._log.debug("Unscheduling all tasks")
+
+ for task_id in self._scheduled_tasks.copy():
+ self.cancel(task_id)
+
+ async def _await_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: typing.Hashable,
+ coroutine: typing.Coroutine
+ ) -> None:
+ """Await ``coroutine`` after ``delay`` seconds."""
+ try:
+ self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.")
+ await asyncio.sleep(delay)
+
+ # Use asyncio.shield to prevent the coroutine from cancelling itself.
+ self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.")
+ await asyncio.shield(coroutine)
+ finally:
+ # Close it to prevent unawaited coroutine warnings,
+ # which would happen if the task was cancelled during the sleep.
+ # Only close it if it's not been awaited yet. This check is important because the
+ # coroutine may cancel this task, which would also trigger the finally block.
+ state = inspect.getcoroutinestate(coroutine)
+ if state == "CORO_CREATED":
+ self._log.debug(f"Explicitly closing the coroutine for #{task_id}.")
+ coroutine.close()
+ else:
+ self._log.debug(f"Finally block reached for #{task_id}; {state=}")
+
+ def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None:
+ """
+ Delete the task and raise its exception if one exists.
+
+ If ``done_task`` and the task associated with ``task_id`` are different, then the latter
+ will not be deleted. In this case, a new task was likely rescheduled with the same ID.
+ """
+ self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.")
+
+ scheduled_task = self._scheduled_tasks.get(task_id)
+
+ if scheduled_task and done_task is scheduled_task:
+ # A task for the ID exists and is the same as the done task.
+ # Since this is the done callback, the task is already done so no need to cancel it.
+ self._log.trace(f"Deleting task #{task_id} {id(done_task)}.")
+ del self._scheduled_tasks[task_id]
+ elif scheduled_task:
+ # A new task was likely rescheduled with the same ID.
+ self._log.debug(
+ f"The scheduled task #{task_id} {id(scheduled_task)} "
+ f"and the done task {id(done_task)} differ."
+ )
+ elif not done_task.cancelled():
+ self._log.warning(
+ f"Task #{task_id} not found while handling task {id(done_task)}! "
+ f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)."
+ )
+
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = done_task.exception()
+ # Log the exception if one exists.
+ if exception:
+ self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
+
+
+def create_task(
+ coro: typing.Awaitable,
+ *,
+ suppressed_exceptions: tuple[typing.Type[Exception]] = (),
+ event_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
+ **kwargs,
+) -> asyncio.Task:
+ """
+ Wrapper for creating an :obj:`asyncio.Task` which logs exceptions raised in the task.
+
+ If the ``event_loop`` kwarg is provided, the task is created from that event loop,
+ otherwise the running loop is used.
+
+ Args:
+ coro: The function to call.
+ suppressed_exceptions: Exceptions to be handled by the task.
+ event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from.
+ kwargs: Passed to :py:func:`asyncio.create_task`.
+
+ Returns:
+ asyncio.Task: The wrapped task.
+ """
+ if event_loop is not None:
+ task = event_loop.create_task(coro, **kwargs)
+ else:
+ task = asyncio.create_task(coro, **kwargs)
+ task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
+ return task
+
+
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None:
+ """Retrieve and log the exception raised in ``task`` if one exists."""
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = task.exception()
+ # Log the exception if one exists.
+ if exception and not isinstance(exception, suppressed_exceptions):
+ log = logging.get_logger(__name__)
+ log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)