diff options
Diffstat (limited to 'botcore')
| -rw-r--r-- | botcore/__init__.py | 13 | ||||
| -rw-r--r-- | botcore/caching.py | 65 | ||||
| -rw-r--r-- | botcore/channel.py | 26 | ||||
| -rw-r--r-- | botcore/extensions.py | 52 | ||||
| -rw-r--r-- | botcore/exts/__init__.py | 4 | ||||
| -rw-r--r-- | botcore/loggers.py | 45 | ||||
| -rw-r--r-- | botcore/members.py | 48 | ||||
| -rw-r--r-- | botcore/scheduling.py | 246 | 
8 files changed, 496 insertions, 3 deletions
| diff --git a/botcore/__init__.py b/botcore/__init__.py index c582d0df..f32e6bf2 100644 --- a/botcore/__init__.py +++ b/botcore/__init__.py @@ -1,9 +1,16 @@ -from botcore import ( -    regex, -) +"""Useful utilities and tools for discord bot development.""" + +from botcore import (caching, channel, extensions, exts, loggers, members, regex, scheduling)  __all__ = [ +    caching, +    channel, +    extensions, +    exts, +    loggers, +    members,      regex, +    scheduling,  ]  __all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/caching.py b/botcore/caching.py new file mode 100644 index 00000000..ea71ed1d --- /dev/null +++ b/botcore/caching.py @@ -0,0 +1,65 @@ +"""Utilities related to custom caches.""" + +import functools +import typing +from collections import OrderedDict + + +class AsyncCache: +    """ +    LRU cache implementation for coroutines. + +    Once the cache exceeds the maximum size, keys are deleted in FIFO order. + +    An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. +    """ + +    def __init__(self, max_size: int = 128): +        """ +        Initialise a new AsyncCache instance. + +        Args: +            max_size: How many items to store in the cache. +        """ +        self._cache = OrderedDict() +        self._max_size = max_size + +    def __call__(self, arg_offset: int = 0) -> typing.Callable: +        """ +        Decorator for async cache. + +        Args: +            arg_offset: The offset for the position of the key argument. + +        Returns: +            A decorator to wrap the target function. +        """ + +        def decorator(function: typing.Callable) -> typing.Callable: +            """ +            Define the async cache decorator. + +            Args: +                function: The function to wrap. + +            Returns: +                The wrapped function. +            """ + +            @functools.wraps(function) +            async def wrapper(*args) -> typing.Any: +                """Decorator wrapper for the caching logic.""" +                key = args[arg_offset:] + +                if key not in self._cache: +                    if len(self._cache) > self._max_size: +                        self._cache.popitem(last=False) + +                    self._cache[key] = await function(*args) +                return self._cache[key] +            return wrapper +        return decorator + +    def clear(self) -> None: +        """Clear cache instance.""" +        self._cache.clear() diff --git a/botcore/channel.py b/botcore/channel.py new file mode 100644 index 00000000..b19b4f08 --- /dev/null +++ b/botcore/channel.py @@ -0,0 +1,26 @@ +"""Utilities for interacting with discord channels.""" + +import discord +from discord.ext.commands import Bot + +from botcore import loggers + +log = loggers.get_logger(__name__) + + +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: +    """Return True if `channel` is within a category with `category_id`.""" +    return getattr(channel, "category_id", None) == category_id + + +async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel: +    """Attempt to get or fetch a channel and return it.""" +    log.trace(f"Getting the channel {channel_id}.") + +    channel = bot.get_channel(channel_id) +    if not channel: +        log.debug(f"Channel {channel_id} is not in cache; fetching from API.") +        channel = await bot.fetch_channel(channel_id) + +    log.trace(f"Channel #{channel} ({channel_id}) retrieved.") +    return channel diff --git a/botcore/extensions.py b/botcore/extensions.py new file mode 100644 index 00000000..c8f200ad --- /dev/null +++ b/botcore/extensions.py @@ -0,0 +1,52 @@ +"""Utilities for loading discord extensions.""" + +import importlib +import inspect +import pkgutil +import types +from typing import NoReturn + + +def unqualify(name: str) -> str: +    """ +    Return an unqualified name given a qualified module/package `name`. + +    Args: +        name: The module name to unqualify. + +    Returns: +        The unqualified module name. +    """ +    return name.rsplit(".", maxsplit=1)[-1] + + +def walk_extensions(module: types.ModuleType) -> frozenset[str]: +    """ +    Yield extension names from the bot.exts subpackage. + +    Args: +        module (types.ModuleType): The module to look for extensions in. + +    Returns: +        A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`. +    """ + +    def on_error(name: str) -> NoReturn: +        raise ImportError(name=name)  # pragma: no cover + +    modules = set() + +    for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error): +        if unqualify(module_info.name).startswith("_"): +            # Ignore module/package names starting with an underscore. +            continue + +        if module_info.ispkg: +            imported = importlib.import_module(module_info.name) +            if not inspect.isfunction(getattr(imported, "setup", None)): +                # If it lacks a setup function, it's not an extension. +                continue + +        modules.add(module_info.name) + +    return frozenset(modules) diff --git a/botcore/exts/__init__.py b/botcore/exts/__init__.py new file mode 100644 index 00000000..029178a9 --- /dev/null +++ b/botcore/exts/__init__.py @@ -0,0 +1,4 @@ +"""Reusable discord cogs.""" +__all__ = [] + +__all__ = list(map(lambda module: module.__name__, __all__)) diff --git a/botcore/loggers.py b/botcore/loggers.py new file mode 100644 index 00000000..ac1db920 --- /dev/null +++ b/botcore/loggers.py @@ -0,0 +1,45 @@ +"""Custom logging class.""" + +import logging +import typing + +if typing.TYPE_CHECKING: +    LoggerClass = logging.Logger +else: +    LoggerClass = logging.getLoggerClass() + +TRACE_LEVEL = 5 + + +class CustomLogger(LoggerClass): +    """Custom implementation of the `Logger` class with an added `trace` method.""" + +    def trace(self, msg: str, *args, **kwargs) -> None: +        """ +        Log 'msg % args' with severity 'TRACE'. + +        To pass exception information, use the keyword argument exc_info with a true value: + +        .. code-block:: py + +            logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) + +        Args: +            msg: The message to be logged. +            args, kwargs: Passed to the base log function as is. +        """ +        if self.isEnabledFor(TRACE_LEVEL): +            self.log(TRACE_LEVEL, msg, *args, **kwargs) + + +def get_logger(name: typing.Optional[str] = None) -> CustomLogger: +    """ +    Utility to make mypy recognise that logger is of type `CustomLogger`. + +    Args: +        name: The name given to the logger. + +    Returns: +        An instance of the `CustomLogger` class. +    """ +    return typing.cast(CustomLogger, logging.getLogger(name)) diff --git a/botcore/members.py b/botcore/members.py new file mode 100644 index 00000000..07b16ea3 --- /dev/null +++ b/botcore/members.py @@ -0,0 +1,48 @@ +import typing + +import discord + +from botcore.loggers import get_logger + +log = get_logger(__name__) + + +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]: +    """ +    Attempt to get a member from cache; on failure fetch from the API. + +    Return `None` to indicate the member could not be found. +    """ +    if member := guild.get_member(member_id): +        log.trace(f"{member} retrieved from cache.") +    else: +        try: +            member = await guild.fetch_member(member_id) +        except discord.errors.NotFound: +            log.trace(f"Failed to fetch {member_id} from API.") +            return None +        log.trace(f"{member} fetched from API.") +    return member + + +async def handle_role_change( +    member: discord.Member, +    coro: typing.Callable[..., typing.Coroutine], +    role: discord.Role +) -> None: +    """ +    Change `member`'s cooldown role via awaiting `coro` and handle errors. + +    `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. +    """ +    try: +        await coro(role) +    except discord.NotFound: +        log.error(f"Failed to change role for {member} ({member.id}): member not found") +    except discord.Forbidden: +        log.error( +            f"Forbidden to change role for {member} ({member.id}); " +            f"possibly due to role hierarchy" +        ) +    except discord.HTTPException as e: +        log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/botcore/scheduling.py b/botcore/scheduling.py new file mode 100644 index 00000000..206e5e79 --- /dev/null +++ b/botcore/scheduling.py @@ -0,0 +1,246 @@ +"""Generic python scheduler.""" + +import asyncio +import contextlib +import inspect +import typing +from datetime import datetime +from functools import partial + +from botcore.loggers import get_logger + + +class Scheduler: +    """ +    Schedule the execution of coroutines and keep track of them. + +    When instantiating a Scheduler, a name must be provided. This name is used to distinguish the +    instance's log messages from other instances. Using the name of the class or module containing +    the instance is suggested. + +    Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at` +    or `schedule_later`. A unique ID is required to be given in order to keep track of the +    resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing +    the same ID used to schedule it.  The `in` operator is supported for checking if a task with a +    given ID is currently scheduled. + +    Any exception raised in a scheduled task is logged when the task is done. +    """ + +    def __init__(self, name: str): +        """ +        Initialize a new Scheduler instance. + +        Args: +            name: The name of the scheduler. Used in logging, and namespacing. +        """ +        self.name = name + +        self._log = get_logger(f"{__name__}.{name}") +        self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {} + +    def __contains__(self, task_id: typing.Hashable) -> bool: +        """ +        Return True if a task with the given `task_id` is currently scheduled. + +        Args: +            task_id: The task to look for. + +        Returns: +            True if the task was found. +        """ +        return task_id in self._scheduled_tasks + +    def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: +        """ +        Schedule the execution of a `coroutine`. + +        If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This +        prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + +        Args: +            task_id: A unique ID to create the task with. +            coroutine: The function to be called. +        """ +        self._log.trace(f"Scheduling task #{task_id}...") + +        msg = f"Cannot schedule an already started coroutine for #{task_id}" +        assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg + +        if task_id in self._scheduled_tasks: +            self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.") +            coroutine.close() +            return + +        task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") +        task.add_done_callback(partial(self._task_done_callback, task_id)) + +        self._scheduled_tasks[task_id] = task +        self._log.debug(f"Scheduled task #{task_id} {id(task)}.") + +    def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None: +        """ +        Schedule `coroutine` to be executed at the given `time`. + +        If `time` is timezone aware, then use that timezone to calculate now() when subtracting. +        If `time` is naïve, then use UTC. + +        If `time` is in the past, schedule `coroutine` immediately. + +        If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This +        prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + +        Args: +            time: The time to start the task. +            task_id: A unique ID to create the task with. +            coroutine: The function to be called. +        """ +        now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow() +        delay = (time - now_datetime).total_seconds() +        if delay > 0: +            coroutine = self._await_later(delay, task_id, coroutine) + +        self.schedule(task_id, coroutine) + +    def schedule_later( +        self, +        delay: typing.Union[int, float], +        task_id: typing.Hashable, +        coroutine: typing.Coroutine +    ) -> None: +        """ +        Schedule `coroutine` to be executed after `delay` seconds. + +        If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This +        prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere. + +        Args: +            delay: How long to wait before starting the task. +            task_id: A unique ID to create the task with. +            coroutine: The function to be called. +        """ +        self.schedule(task_id, self._await_later(delay, task_id, coroutine)) + +    def cancel(self, task_id: typing.Hashable) -> None: +        """ +        Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist. + +        Args: +            task_id: The task's unique ID. +        """ +        self._log.trace(f"Cancelling task #{task_id}...") + +        try: +            task = self._scheduled_tasks.pop(task_id) +        except KeyError: +            self._log.warning(f"Failed to unschedule {task_id} (no task found).") +        else: +            task.cancel() + +            self._log.debug(f"Unscheduled task #{task_id} {id(task)}.") + +    def cancel_all(self) -> None: +        """Unschedule all known tasks.""" +        self._log.debug("Unscheduling all tasks") + +        for task_id in self._scheduled_tasks.copy(): +            self.cancel(task_id) + +    async def _await_later( +        self, +        delay: typing.Union[int, float], +        task_id: typing.Hashable, +        coroutine: typing.Coroutine +    ) -> None: +        """Await `coroutine` after `delay` seconds.""" +        try: +            self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.") +            await asyncio.sleep(delay) + +            # Use asyncio.shield to prevent the coroutine from cancelling itself. +            self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.") +            await asyncio.shield(coroutine) +        finally: +            # Close it to prevent unawaited coroutine warnings, +            # which would happen if the task was cancelled during the sleep. +            # Only close it if it's not been awaited yet. This check is important because the +            # coroutine may cancel this task, which would also trigger the finally block. +            state = inspect.getcoroutinestate(coroutine) +            if state == "CORO_CREATED": +                self._log.debug(f"Explicitly closing the coroutine for #{task_id}.") +                coroutine.close() +            else: +                self._log.debug(f"Finally block reached for #{task_id}; {state=}") + +    def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None: +        """ +        Delete the task and raise its exception if one exists. + +        If `done_task` and the task associated with `task_id` are different, then the latter +        will not be deleted. In this case, a new task was likely rescheduled with the same ID. +        """ +        self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.") + +        scheduled_task = self._scheduled_tasks.get(task_id) + +        if scheduled_task and done_task is scheduled_task: +            # A task for the ID exists and is the same as the done task. +            # Since this is the done callback, the task is already done so no need to cancel it. +            self._log.trace(f"Deleting task #{task_id} {id(done_task)}.") +            del self._scheduled_tasks[task_id] +        elif scheduled_task: +            # A new task was likely rescheduled with the same ID. +            self._log.debug( +                f"The scheduled task #{task_id} {id(scheduled_task)} " +                f"and the done task {id(done_task)} differ." +            ) +        elif not done_task.cancelled(): +            self._log.warning( +                f"Task #{task_id} not found while handling task {id(done_task)}! " +                f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." +            ) + +        with contextlib.suppress(asyncio.CancelledError): +            exception = done_task.exception() +            # Log the exception if one exists. +            if exception: +                self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) + + +def create_task( +    coro: typing.Awaitable, +    *, +    suppressed_exceptions: tuple[typing.Type[Exception]] = (), +    event_loop: typing.Optional[asyncio.AbstractEventLoop] = None, +    **kwargs, +) -> asyncio.Task: +    """ +    Wrapper for creating asyncio `Tasks` which logs exceptions raised in the task. + +    If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used. + +    Args: +        coro: The function to call. +        suppressed_exceptions: Exceptions to be handled by the task. +        event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from. +        kwargs: Passed to :py:func:`asyncio.create_task`. + +    Returns: +        asyncio.Task: The wrapped task. +    """ +    if event_loop is not None: +        task = event_loop.create_task(coro, **kwargs) +    else: +        task = asyncio.create_task(coro, **kwargs) +    task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) +    return task + + +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None: +    """Retrieve and log the exception raised in `task` if one exists.""" +    with contextlib.suppress(asyncio.CancelledError): +        exception = task.exception() +        # Log the exception if one exists. +        if exception and not isinstance(exception, suppressed_exceptions): +            log = get_logger(__name__) +            log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) | 
