aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--botcore/__init__.py13
-rw-r--r--botcore/caching.py65
-rw-r--r--botcore/channel.py26
-rw-r--r--botcore/extensions.py52
-rw-r--r--botcore/exts/__init__.py4
-rw-r--r--botcore/loggers.py45
-rw-r--r--botcore/members.py48
-rw-r--r--botcore/scheduling.py246
-rw-r--r--docs/changelog.rst3
-rw-r--r--docs/conf.py128
-rw-r--r--docs/utils.py117
-rw-r--r--pyproject.toml2
-rw-r--r--tox.ini2
13 files changed, 640 insertions, 111 deletions
diff --git a/botcore/__init__.py b/botcore/__init__.py
index c582d0df..f32e6bf2 100644
--- a/botcore/__init__.py
+++ b/botcore/__init__.py
@@ -1,9 +1,16 @@
-from botcore import (
- regex,
-)
+"""Useful utilities and tools for discord bot development."""
+
+from botcore import (caching, channel, extensions, exts, loggers, members, regex, scheduling)
__all__ = [
+ caching,
+ channel,
+ extensions,
+ exts,
+ loggers,
+ members,
regex,
+ scheduling,
]
__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/botcore/caching.py b/botcore/caching.py
new file mode 100644
index 00000000..ea71ed1d
--- /dev/null
+++ b/botcore/caching.py
@@ -0,0 +1,65 @@
+"""Utilities related to custom caches."""
+
+import functools
+import typing
+from collections import OrderedDict
+
+
+class AsyncCache:
+ """
+ LRU cache implementation for coroutines.
+
+ Once the cache exceeds the maximum size, keys are deleted in FIFO order.
+
+ An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key.
+ """
+
+ def __init__(self, max_size: int = 128):
+ """
+ Initialise a new AsyncCache instance.
+
+ Args:
+ max_size: How many items to store in the cache.
+ """
+ self._cache = OrderedDict()
+ self._max_size = max_size
+
+ def __call__(self, arg_offset: int = 0) -> typing.Callable:
+ """
+ Decorator for async cache.
+
+ Args:
+ arg_offset: The offset for the position of the key argument.
+
+ Returns:
+ A decorator to wrap the target function.
+ """
+
+ def decorator(function: typing.Callable) -> typing.Callable:
+ """
+ Define the async cache decorator.
+
+ Args:
+ function: The function to wrap.
+
+ Returns:
+ The wrapped function.
+ """
+
+ @functools.wraps(function)
+ async def wrapper(*args) -> typing.Any:
+ """Decorator wrapper for the caching logic."""
+ key = args[arg_offset:]
+
+ if key not in self._cache:
+ if len(self._cache) > self._max_size:
+ self._cache.popitem(last=False)
+
+ self._cache[key] = await function(*args)
+ return self._cache[key]
+ return wrapper
+ return decorator
+
+ def clear(self) -> None:
+ """Clear cache instance."""
+ self._cache.clear()
diff --git a/botcore/channel.py b/botcore/channel.py
new file mode 100644
index 00000000..b19b4f08
--- /dev/null
+++ b/botcore/channel.py
@@ -0,0 +1,26 @@
+"""Utilities for interacting with discord channels."""
+
+import discord
+from discord.ext.commands import Bot
+
+from botcore import loggers
+
+log = loggers.get_logger(__name__)
+
+
+def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
+ """Return True if `channel` is within a category with `category_id`."""
+ return getattr(channel, "category_id", None) == category_id
+
+
+async def get_or_fetch_channel(bot: Bot, channel_id: int) -> discord.abc.GuildChannel:
+ """Attempt to get or fetch a channel and return it."""
+ log.trace(f"Getting the channel {channel_id}.")
+
+ channel = bot.get_channel(channel_id)
+ if not channel:
+ log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
+ channel = await bot.fetch_channel(channel_id)
+
+ log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
+ return channel
diff --git a/botcore/extensions.py b/botcore/extensions.py
new file mode 100644
index 00000000..c8f200ad
--- /dev/null
+++ b/botcore/extensions.py
@@ -0,0 +1,52 @@
+"""Utilities for loading discord extensions."""
+
+import importlib
+import inspect
+import pkgutil
+import types
+from typing import NoReturn
+
+
+def unqualify(name: str) -> str:
+ """
+ Return an unqualified name given a qualified module/package `name`.
+
+ Args:
+ name: The module name to unqualify.
+
+ Returns:
+ The unqualified module name.
+ """
+ return name.rsplit(".", maxsplit=1)[-1]
+
+
+def walk_extensions(module: types.ModuleType) -> frozenset[str]:
+ """
+ Yield extension names from the bot.exts subpackage.
+
+ Args:
+ module (types.ModuleType): The module to look for extensions in.
+
+ Returns:
+ A set of strings that can be passed directly to :obj:`discord.ext.commands.Bot.load_extension`.
+ """
+
+ def on_error(name: str) -> NoReturn:
+ raise ImportError(name=name) # pragma: no cover
+
+ modules = set()
+
+ for module_info in pkgutil.walk_packages(module.__path__, f"{module.__name__}.", onerror=on_error):
+ if unqualify(module_info.name).startswith("_"):
+ # Ignore module/package names starting with an underscore.
+ continue
+
+ if module_info.ispkg:
+ imported = importlib.import_module(module_info.name)
+ if not inspect.isfunction(getattr(imported, "setup", None)):
+ # If it lacks a setup function, it's not an extension.
+ continue
+
+ modules.add(module_info.name)
+
+ return frozenset(modules)
diff --git a/botcore/exts/__init__.py b/botcore/exts/__init__.py
new file mode 100644
index 00000000..029178a9
--- /dev/null
+++ b/botcore/exts/__init__.py
@@ -0,0 +1,4 @@
+"""Reusable discord cogs."""
+__all__ = []
+
+__all__ = list(map(lambda module: module.__name__, __all__))
diff --git a/botcore/loggers.py b/botcore/loggers.py
new file mode 100644
index 00000000..ac1db920
--- /dev/null
+++ b/botcore/loggers.py
@@ -0,0 +1,45 @@
+"""Custom logging class."""
+
+import logging
+import typing
+
+if typing.TYPE_CHECKING:
+ LoggerClass = logging.Logger
+else:
+ LoggerClass = logging.getLoggerClass()
+
+TRACE_LEVEL = 5
+
+
+class CustomLogger(LoggerClass):
+ """Custom implementation of the `Logger` class with an added `trace` method."""
+
+ def trace(self, msg: str, *args, **kwargs) -> None:
+ """
+ Log 'msg % args' with severity 'TRACE'.
+
+ To pass exception information, use the keyword argument exc_info with a true value:
+
+ .. code-block:: py
+
+ logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
+
+ Args:
+ msg: The message to be logged.
+ args, kwargs: Passed to the base log function as is.
+ """
+ if self.isEnabledFor(TRACE_LEVEL):
+ self.log(TRACE_LEVEL, msg, *args, **kwargs)
+
+
+def get_logger(name: typing.Optional[str] = None) -> CustomLogger:
+ """
+ Utility to make mypy recognise that logger is of type `CustomLogger`.
+
+ Args:
+ name: The name given to the logger.
+
+ Returns:
+ An instance of the `CustomLogger` class.
+ """
+ return typing.cast(CustomLogger, logging.getLogger(name))
diff --git a/botcore/members.py b/botcore/members.py
new file mode 100644
index 00000000..07b16ea3
--- /dev/null
+++ b/botcore/members.py
@@ -0,0 +1,48 @@
+import typing
+
+import discord
+
+from botcore.loggers import get_logger
+
+log = get_logger(__name__)
+
+
+async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> typing.Optional[discord.Member]:
+ """
+ Attempt to get a member from cache; on failure fetch from the API.
+
+ Return `None` to indicate the member could not be found.
+ """
+ if member := guild.get_member(member_id):
+ log.trace(f"{member} retrieved from cache.")
+ else:
+ try:
+ member = await guild.fetch_member(member_id)
+ except discord.errors.NotFound:
+ log.trace(f"Failed to fetch {member_id} from API.")
+ return None
+ log.trace(f"{member} fetched from API.")
+ return member
+
+
+async def handle_role_change(
+ member: discord.Member,
+ coro: typing.Callable[..., typing.Coroutine],
+ role: discord.Role
+) -> None:
+ """
+ Change `member`'s cooldown role via awaiting `coro` and handle errors.
+
+ `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`.
+ """
+ try:
+ await coro(role)
+ except discord.NotFound:
+ log.error(f"Failed to change role for {member} ({member.id}): member not found")
+ except discord.Forbidden:
+ log.error(
+ f"Forbidden to change role for {member} ({member.id}); "
+ f"possibly due to role hierarchy"
+ )
+ except discord.HTTPException as e:
+ log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}")
diff --git a/botcore/scheduling.py b/botcore/scheduling.py
new file mode 100644
index 00000000..206e5e79
--- /dev/null
+++ b/botcore/scheduling.py
@@ -0,0 +1,246 @@
+"""Generic python scheduler."""
+
+import asyncio
+import contextlib
+import inspect
+import typing
+from datetime import datetime
+from functools import partial
+
+from botcore.loggers import get_logger
+
+
+class Scheduler:
+ """
+ Schedule the execution of coroutines and keep track of them.
+
+ When instantiating a Scheduler, a name must be provided. This name is used to distinguish the
+ instance's log messages from other instances. Using the name of the class or module containing
+ the instance is suggested.
+
+ Coroutines can be scheduled immediately with `schedule` or in the future with `schedule_at`
+ or `schedule_later`. A unique ID is required to be given in order to keep track of the
+ resulting Tasks. Any scheduled task can be cancelled prematurely using `cancel` by providing
+ the same ID used to schedule it. The `in` operator is supported for checking if a task with a
+ given ID is currently scheduled.
+
+ Any exception raised in a scheduled task is logged when the task is done.
+ """
+
+ def __init__(self, name: str):
+ """
+ Initialize a new Scheduler instance.
+
+ Args:
+ name: The name of the scheduler. Used in logging, and namespacing.
+ """
+ self.name = name
+
+ self._log = get_logger(f"{__name__}.{name}")
+ self._scheduled_tasks: typing.Dict[typing.Hashable, asyncio.Task] = {}
+
+ def __contains__(self, task_id: typing.Hashable) -> bool:
+ """
+ Return True if a task with the given `task_id` is currently scheduled.
+
+ Args:
+ task_id: The task to look for.
+
+ Returns:
+ True if the task was found.
+ """
+ return task_id in self._scheduled_tasks
+
+ def schedule(self, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ """
+ Schedule the execution of a `coroutine`.
+
+ If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self._log.trace(f"Scheduling task #{task_id}...")
+
+ msg = f"Cannot schedule an already started coroutine for #{task_id}"
+ assert inspect.getcoroutinestate(coroutine) == "CORO_CREATED", msg
+
+ if task_id in self._scheduled_tasks:
+ self._log.debug(f"Did not schedule task #{task_id}; task was already scheduled.")
+ coroutine.close()
+ return
+
+ task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}")
+ task.add_done_callback(partial(self._task_done_callback, task_id))
+
+ self._scheduled_tasks[task_id] = task
+ self._log.debug(f"Scheduled task #{task_id} {id(task)}.")
+
+ def schedule_at(self, time: datetime, task_id: typing.Hashable, coroutine: typing.Coroutine) -> None:
+ """
+ Schedule `coroutine` to be executed at the given `time`.
+
+ If `time` is timezone aware, then use that timezone to calculate now() when subtracting.
+ If `time` is naïve, then use UTC.
+
+ If `time` is in the past, schedule `coroutine` immediately.
+
+ If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ time: The time to start the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ now_datetime = datetime.now(time.tzinfo) if time.tzinfo else datetime.utcnow()
+ delay = (time - now_datetime).total_seconds()
+ if delay > 0:
+ coroutine = self._await_later(delay, task_id, coroutine)
+
+ self.schedule(task_id, coroutine)
+
+ def schedule_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: typing.Hashable,
+ coroutine: typing.Coroutine
+ ) -> None:
+ """
+ Schedule `coroutine` to be executed after `delay` seconds.
+
+ If a task with `task_id` already exists, close `coroutine` instead of scheduling it. This
+ prevents unawaited coroutine warnings. Don't pass a coroutine that'll be re-used elsewhere.
+
+ Args:
+ delay: How long to wait before starting the task.
+ task_id: A unique ID to create the task with.
+ coroutine: The function to be called.
+ """
+ self.schedule(task_id, self._await_later(delay, task_id, coroutine))
+
+ def cancel(self, task_id: typing.Hashable) -> None:
+ """
+ Unschedule the task identified by `task_id`. Log a warning if the task doesn't exist.
+
+ Args:
+ task_id: The task's unique ID.
+ """
+ self._log.trace(f"Cancelling task #{task_id}...")
+
+ try:
+ task = self._scheduled_tasks.pop(task_id)
+ except KeyError:
+ self._log.warning(f"Failed to unschedule {task_id} (no task found).")
+ else:
+ task.cancel()
+
+ self._log.debug(f"Unscheduled task #{task_id} {id(task)}.")
+
+ def cancel_all(self) -> None:
+ """Unschedule all known tasks."""
+ self._log.debug("Unscheduling all tasks")
+
+ for task_id in self._scheduled_tasks.copy():
+ self.cancel(task_id)
+
+ async def _await_later(
+ self,
+ delay: typing.Union[int, float],
+ task_id: typing.Hashable,
+ coroutine: typing.Coroutine
+ ) -> None:
+ """Await `coroutine` after `delay` seconds."""
+ try:
+ self._log.trace(f"Waiting {delay} seconds before awaiting coroutine for #{task_id}.")
+ await asyncio.sleep(delay)
+
+ # Use asyncio.shield to prevent the coroutine from cancelling itself.
+ self._log.trace(f"Done waiting for #{task_id}; now awaiting the coroutine.")
+ await asyncio.shield(coroutine)
+ finally:
+ # Close it to prevent unawaited coroutine warnings,
+ # which would happen if the task was cancelled during the sleep.
+ # Only close it if it's not been awaited yet. This check is important because the
+ # coroutine may cancel this task, which would also trigger the finally block.
+ state = inspect.getcoroutinestate(coroutine)
+ if state == "CORO_CREATED":
+ self._log.debug(f"Explicitly closing the coroutine for #{task_id}.")
+ coroutine.close()
+ else:
+ self._log.debug(f"Finally block reached for #{task_id}; {state=}")
+
+ def _task_done_callback(self, task_id: typing.Hashable, done_task: asyncio.Task) -> None:
+ """
+ Delete the task and raise its exception if one exists.
+
+ If `done_task` and the task associated with `task_id` are different, then the latter
+ will not be deleted. In this case, a new task was likely rescheduled with the same ID.
+ """
+ self._log.trace(f"Performing done callback for task #{task_id} {id(done_task)}.")
+
+ scheduled_task = self._scheduled_tasks.get(task_id)
+
+ if scheduled_task and done_task is scheduled_task:
+ # A task for the ID exists and is the same as the done task.
+ # Since this is the done callback, the task is already done so no need to cancel it.
+ self._log.trace(f"Deleting task #{task_id} {id(done_task)}.")
+ del self._scheduled_tasks[task_id]
+ elif scheduled_task:
+ # A new task was likely rescheduled with the same ID.
+ self._log.debug(
+ f"The scheduled task #{task_id} {id(scheduled_task)} "
+ f"and the done task {id(done_task)} differ."
+ )
+ elif not done_task.cancelled():
+ self._log.warning(
+ f"Task #{task_id} not found while handling task {id(done_task)}! "
+ f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)."
+ )
+
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = done_task.exception()
+ # Log the exception if one exists.
+ if exception:
+ self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
+
+
+def create_task(
+ coro: typing.Awaitable,
+ *,
+ suppressed_exceptions: tuple[typing.Type[Exception]] = (),
+ event_loop: typing.Optional[asyncio.AbstractEventLoop] = None,
+ **kwargs,
+) -> asyncio.Task:
+ """
+ Wrapper for creating asyncio `Tasks` which logs exceptions raised in the task.
+
+ If the loop kwarg is provided, the task is created from that event loop, otherwise the running loop is used.
+
+ Args:
+ coro: The function to call.
+ suppressed_exceptions: Exceptions to be handled by the task.
+ event_loop (:obj:`asyncio.AbstractEventLoop`): The loop to create the task from.
+ kwargs: Passed to :py:func:`asyncio.create_task`.
+
+ Returns:
+ asyncio.Task: The wrapped task.
+ """
+ if event_loop is not None:
+ task = event_loop.create_task(coro, **kwargs)
+ else:
+ task = asyncio.create_task(coro, **kwargs)
+ task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
+ return task
+
+
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: typing.Tuple[typing.Type[Exception]]) -> None:
+ """Retrieve and log the exception raised in `task` if one exists."""
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = task.exception()
+ # Log the exception if one exists.
+ if exception and not isinstance(exception, suppressed_exceptions):
+ log = get_logger(__name__)
+ log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 743fcc20..25d01756 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -4,7 +4,8 @@
Changelog
=========
-
+- :release:`1.2.1 <22nd February 2022>`
+- :support:`3` Added intersphinx to docs.
- :release:`1.2.0 <9th January 2022>`
- :feature:`12` Code block detection regex
- :release:`1.1.0 <2nd December 2021>`
diff --git a/docs/conf.py b/docs/conf.py
index 4ab831d3..e2801e2c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,16 +1,14 @@
# Configuration file for the Sphinx documentation builder.
# https://www.sphinx-doc.org/en/master/usage/configuration.html
-import ast
-import importlib
-import inspect
+import functools
import sys
-import typing
from pathlib import Path
import git
import tomli
-from sphinx.application import Sphinx
+
+from docs import utils
# -- Project information -----------------------------------------------------
@@ -38,6 +36,7 @@ add_module_names = False
# ones.
extensions = [
"sphinx.ext.extlinks",
+ "sphinx.ext.intersphinx",
"sphinx.ext.autodoc",
"sphinx.ext.todo",
"sphinx.ext.napoleon",
@@ -89,45 +88,6 @@ html_js_files = [
]
-# -- Autodoc cleanup ---------------------------------------------------------
-# Clean up the output generated by autodoc to produce a nicer documentation tree
-# This is kept in a function to avoid polluting the namespace
-def __cleanup() -> None:
- for file in (PROJECT_ROOT / "docs" / "output").iterdir():
- if file.name == "modules.rst":
- # We only have one module, so this is redundant
- # Remove it and flatten out the tree
- file.unlink()
-
- elif file.name == "botcore.rst":
- # We want to bring the submodule name to the top, and remove anything that's not a submodule
- result = ""
- for line in file.read_text(encoding="utf-8").splitlines(keepends=True):
- if ".." not in line and result == "":
- # We have not reached the first submodule, this is all filler
- continue
- elif "Module contents" in line:
- # We have parsed all the submodules, so let's skip the redudant module name
- break
- result += line
-
- result = "Botcore\n=======\n\n" + result
- file.write_text(result, encoding="utf-8")
-
- else:
- # Clean up the submodule name so it's just the name without the top level module name
- # example: `botcore.regex module` -> `regex`
- lines = file.read_text(encoding="utf-8").splitlines()
- lines[0] = lines[0].replace("botcore.", "").replace("module", "").strip()
-
- # Take the opportunity to configure autodoc
- lines = "\n".join(lines).replace("undoc-members", "special-members")
- file.write_text(lines, encoding="utf-8")
-
-
-__cleanup()
-
-
def skip(*args) -> bool:
"""Things that should be skipped by the autodoc generation."""
name = args[2]
@@ -140,11 +100,6 @@ def skip(*args) -> bool:
return would_skip
-def setup(app: Sphinx) -> None:
- """Add extra hook-based autodoc configuration."""
- app.connect("autodoc-skip-member", skip)
-
-
# -- Extension configuration -------------------------------------------------
# -- Options for todo extension ----------------------------------------------
@@ -169,60 +124,23 @@ extlinks = {
}
+# -- Options for intersphinx extension ---------------------------------------
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3", None),
+ "discord": ("https://discordpy.readthedocs.io/en/master/", None),
+}
+
+
+# -- Options for the autodoc extension ---------------------------------------
+utils.cleanup()
+autodoc_default_options = {
+ "members": True,
+ "special-members": True,
+ "show-inheritance": True,
+ "imported-members": False,
+ "exclude-members": "__weakref__"
+}
+
+
# -- Options for the linkcode extension --------------------------------------
-def linkcode_resolve(domain: str, info: dict[str, str]) -> typing.Optional[str]:
- """
- Function called by linkcode to get the URL for a given resource.
-
- See for more details:
- https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html#confval-linkcode_resolve
- """
- if domain != "py":
- raise Exception("Unknown domain passed to linkcode function.")
-
- symbol_name = info["fullname"]
-
- module = importlib.import_module(info["module"])
-
- symbol = [module]
- for name in symbol_name.split("."):
- symbol.append(getattr(symbol[-1], name))
- symbol_name = name
-
- try:
- lines, start = inspect.getsourcelines(symbol[-1])
- end = start + len(lines)
- except TypeError:
- # Find variables by parsing the ast
- source = ast.parse(inspect.getsource(symbol[-2]))
- while isinstance(source.body[0], ast.ClassDef):
- source = source.body[0]
-
- for ast_obj in source.body:
- if isinstance(ast_obj, ast.Assign):
- names = []
- for target in ast_obj.targets:
- if isinstance(target, ast.Tuple):
- names.extend([name.id for name in target.elts])
- else:
- names.append(target.id)
-
- if symbol_name in names:
- start, end = ast_obj.lineno, ast_obj.end_lineno
- break
- else:
- raise Exception(f"Could not find symbol `{symbol_name}` in {module.__name__}.")
-
- _, offset = inspect.getsourcelines(symbol[-2])
- if offset != 0:
- offset -= 1
- start += offset
- end += offset
-
- file = Path(inspect.getfile(module)).relative_to(PROJECT_ROOT).as_posix()
-
- url = f"{SOURCE_FILE_LINK}/{file}#L{start}"
- if end != start:
- url += f"-L{end}"
-
- return url
+linkcode_resolve = functools.partial(utils.linkcode_resolve, SOURCE_FILE_LINK)
diff --git a/docs/utils.py b/docs/utils.py
new file mode 100644
index 00000000..8bc69ccd
--- /dev/null
+++ b/docs/utils.py
@@ -0,0 +1,117 @@
+"""Utilities used in generating docs."""
+
+import ast
+import importlib
+import inspect
+import typing
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).parent.parent
+
+
+def linkcode_resolve(source_url: str, domain: str, info: dict[str, str]) -> typing.Optional[str]:
+ """
+ Function called by linkcode to get the URL for a given resource.
+
+ See for more details:
+ https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html#confval-linkcode_resolve
+ """
+ if domain != "py":
+ raise Exception("Unknown domain passed to linkcode function.")
+
+ symbol_name = info["fullname"]
+
+ module = importlib.import_module(info["module"])
+
+ symbol = [module]
+ for name in symbol_name.split("."):
+ symbol.append(getattr(symbol[-1], name))
+ symbol_name = name
+
+ try:
+ lines, start = inspect.getsourcelines(symbol[-1])
+ end = start + len(lines)
+ except TypeError:
+ # Find variables by parsing the ast
+ source = ast.parse(inspect.getsource(symbol[-2]))
+ while isinstance(source.body[0], ast.ClassDef):
+ source = source.body[0]
+
+ for ast_obj in source.body:
+ if isinstance(ast_obj, ast.Assign):
+ names = []
+ for target in ast_obj.targets:
+ if isinstance(target, ast.Tuple):
+ names.extend([name.id for name in target.elts])
+ else:
+ names.append(target.id)
+
+ if symbol_name in names:
+ start, end = ast_obj.lineno, ast_obj.end_lineno
+ break
+ else:
+ raise Exception(f"Could not find symbol `{symbol_name}` in {module.__name__}.")
+
+ _, offset = inspect.getsourcelines(symbol[-2])
+ if offset != 0:
+ offset -= 1
+ start += offset
+ end += offset
+
+ file = Path(inspect.getfile(module)).relative_to(PROJECT_ROOT).as_posix()
+
+ url = f"{source_url}/{file}#L{start}"
+ if end != start:
+ url += f"-L{end}"
+
+ return url
+
+
+def cleanup() -> None:
+ """Remove unneeded autogenerated doc files, and clean up others."""
+ included = __get_included()
+
+ for file in (PROJECT_ROOT / "docs" / "output").iterdir():
+ if file.name in ("botcore.rst", "botcore.exts.rst") and file.name in included:
+ content = file.read_text(encoding="utf-8").splitlines(keepends=True)
+
+ # Rename the extension to be less wordy
+ # Example: botcore.exts -> Botcore Exts
+ title = content[0].split()[0].strip().replace("botcore.", "").replace(".", " ").title()
+ title = f"{title}\n{'=' * len(title)}\n\n"
+ content[0:2] = title
+
+ file.write_text("".join(content), encoding="utf-8")
+
+ elif file.name in included:
+ # Clean up the submodule name so it's just the name without the top level module name
+ # example: `botcore.regex module` -> `regex`
+ lines = file.read_text(encoding="utf-8").splitlines(keepends=True)
+ lines[0] = lines[0].replace("module", "").strip().split(".")[-1] + "\n"
+ file.write_text("".join(lines))
+
+ else:
+ # These are files that have not been explicitly included in the docs via __all__
+ print("Deleted file", file.name)
+ file.unlink()
+ continue
+
+ # Take the opportunity to configure autodoc
+ content = file.read_text(encoding="utf-8").replace("undoc-members", "special-members")
+ file.write_text(content, encoding="utf-8")
+
+
+def __get_included() -> set[str]:
+ """Get a list of files that should be included in the final build."""
+
+ def get_all_from_module(module_name: str) -> set[str]:
+ module = importlib.import_module(module_name)
+ _modules = {module.__name__ + ".rst"}
+
+ if hasattr(module, "__all__"):
+ for sub_module in module.__all__:
+ _modules.update(get_all_from_module(sub_module))
+
+ return _modules
+
+ return get_all_from_module("botcore")
diff --git a/pyproject.toml b/pyproject.toml
index 99549f3a..37983b82 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,7 +45,7 @@ furo = "2022.1.2"
[tool.taskipy.tasks]
lint = "pre-commit run --all-files"
precommit = "pre-commit install"
-apidoc = "sphinx-apidoc -o docs/output botcore -fe"
+apidoc = "sphinx-apidoc -o docs/output botcore -feM"
builddoc = "sphinx-build -nW -j auto -b html docs docs/build"
docs = "task apidoc && task builddoc"
diff --git a/tox.ini b/tox.ini
index 9472c32f..e0145e7a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,7 +2,7 @@
max-line-length=120
docstring-convention=all
import-order-style=pycharm
-application_import_names=bot,tests
+application_import_names=botcore,docs,tests
exclude=.cache,.venv,.git,constants.py
ignore=
B311,W503,E226,S311,T000,E731