diff options
-rw-r--r-- | bot/decorators.py | 99 | ||||
-rw-r--r-- | bot/errors.py | 20 | ||||
-rw-r--r-- | bot/exts/backend/error_handler.py | 3 | ||||
-rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 4 | ||||
-rw-r--r-- | bot/exts/utils/reminders.py | 50 | ||||
-rw-r--r-- | bot/utils/function.py | 75 | ||||
-rw-r--r-- | bot/utils/lock.py | 114 |
7 files changed, 265 insertions, 100 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index 2518124da..063c8f878 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,16 +1,15 @@ +import asyncio import logging -import random -from asyncio import Lock, create_task, sleep +import typing as t from contextlib import suppress from functools import wraps -from typing import Callable, Container, Optional, Union -from weakref import WeakValueDictionary -from discord import Colour, Embed, Member, NotFound +from discord import Member, NotFound from discord.ext import commands from discord.ext.commands import Cog, Context -from bot.constants import Channels, ERROR_REPLIES, RedirectOutput +from bot.constants import Channels, RedirectOutput +from bot.utils import function from bot.utils.checks import in_whitelist_check log = logging.getLogger(__name__) @@ -18,12 +17,12 @@ log = logging.getLogger(__name__) def in_whitelist( *, - channels: Container[int] = (), - categories: Container[int] = (), - roles: Container[int] = (), - redirect: Optional[int] = Channels.bot_commands, + channels: t.Container[int] = (), + categories: t.Container[int] = (), + roles: t.Container[int] = (), + redirect: t.Optional[int] = Channels.bot_commands, fail_silently: bool = False, -) -> Callable: +) -> t.Callable: """ Check if a command was issued in a whitelisted context. @@ -31,7 +30,7 @@ def in_whitelist( - `channels`: a container with channel ids for whitelisted channels - `categories`: a container with category ids for whitelisted categories - - `roles`: a container with with role ids for whitelisted roles + - `roles`: a container with role ids for whitelisted roles If the command was invoked in a context that was not whitelisted, the member is either redirected to the `redirect` channel that was passed (default: #bot-commands) or simply @@ -44,7 +43,7 @@ def in_whitelist( return commands.check(predicate) -def has_no_roles(*roles: Union[str, int]) -> Callable: +def has_no_roles(*roles: t.Union[str, int]) -> t.Callable: """ Returns True if the user does not have any of the roles specified. @@ -63,39 +62,7 @@ def has_no_roles(*roles: Union[str, int]) -> Callable: return commands.check(predicate) -def locked() -> Callable: - """ - Allows the user to only run one instance of the decorated command at a time. - - Subsequent calls to the command from the same author are ignored until the command has completed invocation. - - This decorator must go before (below) the `command` decorator. - """ - def wrap(func: Callable) -> Callable: - func.__locks = WeakValueDictionary() - - @wraps(func) - async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: - lock = func.__locks.setdefault(ctx.author.id, Lock()) - if lock.locked(): - embed = Embed() - embed.colour = Colour.red() - - log.debug("User tried to invoke a locked command.") - embed.description = ( - "You're already using this command. Please wait until it is done before you use it again." - ) - embed.title = random.choice(ERROR_REPLIES) - await ctx.send(embed=embed) - return - - async with func.__locks.setdefault(ctx.author.id, Lock()): - await func(self, ctx, *args, **kwargs) - return inner - return wrap - - -def redirect_output(destination_channel: int, bypass_roles: Container[int] = None) -> Callable: +def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable: """ Changes the channel in the context of the command to redirect the output to a certain channel. @@ -103,7 +70,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non This decorator must go before (below) the `command` decorator. """ - def wrap(func: Callable) -> Callable: + def wrap(func: t.Callable) -> t.Callable: @wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: if ctx.channel.id == destination_channel: @@ -122,14 +89,14 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non log.trace(f"Redirecting output of {ctx.author}'s command '{ctx.command.name}' to {redirect_channel.name}") ctx.channel = redirect_channel await ctx.channel.send(f"Here's the output of your command, {ctx.author.mention}") - create_task(func(self, ctx, *args, **kwargs)) + asyncio.create_task(func(self, ctx, *args, **kwargs)) message = await old_channel.send( f"Hey, {ctx.author.mention}, you can find the output of your command here: " f"{redirect_channel.mention}" ) if RedirectOutput.delete_invocation: - await sleep(RedirectOutput.delete_delay) + await asyncio.sleep(RedirectOutput.delete_delay) with suppress(NotFound): await message.delete() @@ -143,38 +110,35 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non return wrap -def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable: +def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable: """ Ensure the highest role of the invoking member is greater than that of the target member. If the condition fails, a warning is sent to the invoking context. A target which is not an instance of discord.Member will always pass. - A value of 0 (i.e. position 0) for `target_arg` corresponds to the argument which comes after - `ctx`. If the target argument is a kwarg, its name can instead be given. + `member_arg` is the keyword name or position index of the parameter of the decorated command + whose value is the target member. This decorator must go before (below) the `command` decorator. """ - def wrap(func: Callable) -> Callable: + def decorator(func: t.Callable) -> t.Callable: @wraps(func) - async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: - try: - target = kwargs[target_arg] - except KeyError: - try: - target = args[target_arg] - except IndexError: - raise ValueError(f"Could not find target argument at position {target_arg}") - except TypeError: - raise ValueError(f"Could not find target kwarg with key {target_arg!r}") + async def wrapper(*args, **kwargs) -> None: + log.trace(f"{func.__name__}: respect role hierarchy decorator called") + + bound_args = function.get_bound_args(func, args, kwargs) + target = function.get_arg_value(member_arg, bound_args) if not isinstance(target, Member): log.trace("The target is not a discord.Member; skipping role hierarchy check.") - await func(self, ctx, *args, **kwargs) + await func(*args, **kwargs) return + ctx = function.get_arg_value(1, bound_args) cmd = ctx.command.name actor = ctx.author + if target.top_role >= actor.top_role: log.info( f"{actor} ({actor.id}) attempted to {cmd} " @@ -185,6 +149,7 @@ def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable: "someone with an equal or higher top role." ) else: - await func(self, ctx, *args, **kwargs) - return inner - return wrap + log.trace(f"{func.__name__}: {target.top_role=} < {actor.top_role=}; calling func") + await func(*args, **kwargs) + return wrapper + return decorator diff --git a/bot/errors.py b/bot/errors.py new file mode 100644 index 000000000..65d715203 --- /dev/null +++ b/bot/errors.py @@ -0,0 +1,20 @@ +from typing import Hashable + + +class LockedResourceError(RuntimeError): + """ + Exception raised when an operation is attempted on a locked resource. + + Attributes: + `type` -- name of the locked resource's type + `id` -- ID of the locked resource + """ + + def __init__(self, resource_type: str, resource_id: Hashable): + self.type = resource_type + self.id = resource_id + + super().__init__( + f"Cannot operate on {self.type.lower()} `{self.id}`; " + "it is currently locked and in use by another operation." + ) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index f9d4de638..c643d346e 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -10,6 +10,7 @@ from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels, Colours from bot.converters import TagNameConverter +from bot.errors import LockedResourceError from bot.utils.checks import InWhitelistCheckFailure log = logging.getLogger(__name__) @@ -75,6 +76,8 @@ class ErrorHandler(Cog): elif isinstance(e, errors.CommandInvokeError): if isinstance(e.original, ResponseCodeError): await self.handle_api_error(ctx, e.original) + elif isinstance(e.original, LockedResourceError): + await ctx.send(f"{e.original} Please wait for it to finish and try again later.") else: await self.handle_unexpected_error(ctx, e.original) return # Exit early to avoid logging. diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index ef6f6e3c6..a8b3feb38 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -230,7 +230,7 @@ class Infractions(InfractionScheduler, commands.Cog): await self.apply_infraction(ctx, infraction, user, action()) - @respect_role_hierarchy() + @respect_role_hierarchy(member_arg=2) async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: """Apply a kick infraction with kwargs passed to `post_infraction`.""" infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs) @@ -245,7 +245,7 @@ class Infractions(InfractionScheduler, commands.Cog): action = user.kick(reason=reason) await self.apply_infraction(ctx, infraction, user, action) - @respect_role_hierarchy() + @respect_role_hierarchy(member_arg=2) async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: """ Apply a ban infraction with kwargs passed to `post_infraction`. diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index b8c15853f..bf4e24661 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -16,12 +16,14 @@ from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Role from bot.converters import Duration from bot.pagination import LinePaginator from bot.utils.checks import has_any_role_check, has_no_roles_check +from bot.utils.lock import lock_arg from bot.utils.messages import send_denial from bot.utils.scheduling import Scheduler from bot.utils.time import humanize_delta log = logging.getLogger(__name__) +NAMESPACE = "reminder" # Used for the mutually_exclusive decorator; constant to prevent typos WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 @@ -52,7 +54,7 @@ class Reminders(Cog): now = datetime.utcnow() for reminder in response: - is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + is_valid, *_ = self.ensure_valid_reminder(reminder) if not is_valid: continue @@ -65,11 +67,7 @@ class Reminders(Cog): else: self.schedule_reminder(reminder) - def ensure_valid_reminder( - self, - reminder: dict, - cancel_task: bool = True - ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.User, discord.TextChannel]: """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" user = self.bot.get_user(reminder['author']) channel = self.bot.get_channel(reminder['channel_id']) @@ -80,7 +78,7 @@ class Reminders(Cog): f"Reminder {reminder['id']} invalid: " f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." ) - asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + asyncio.create_task(self.bot.api_client.delete(f"bot/reminders/{reminder['id']}")) return is_valid, user, channel @@ -88,7 +86,7 @@ class Reminders(Cog): async def _send_confirmation( ctx: Context, on_success: str, - reminder_id: str, + reminder_id: t.Union[str, int], delivery_dt: t.Optional[datetime], ) -> None: """Send an embed confirming the reminder change was made successfully.""" @@ -148,24 +146,8 @@ class Reminders(Cog): def schedule_reminder(self, reminder: dict) -> None: """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" - reminder_id = reminder["id"] reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) - - async def _remind() -> None: - await self.send_reminder(reminder) - - log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") - await self._delete_reminder(reminder_id) - - self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind()) - - async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: - """Delete a reminder from the database, given its ID, and cancel the running task.""" - await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - - if cancel_task: - # Now we can remove it from the schedule list - self.scheduler.cancel(reminder_id) + self.scheduler.schedule_at(reminder_datetime, reminder["id"], self.send_reminder(reminder)) async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict: """ @@ -188,10 +170,12 @@ class Reminders(Cog): log.trace(f"Scheduling new task #{reminder['id']}") self.schedule_reminder(reminder) + @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True) async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: """Send the reminder.""" is_valid, user, channel = self.ensure_valid_reminder(reminder) if not is_valid: + # No need to cancel the task too; it'll simply be done once this coroutine returns. return embed = discord.Embed() @@ -217,11 +201,10 @@ class Reminders(Cog): mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"]) ) - await channel.send( - content=f"{user.mention} {additional_mentions}", - embed=embed - ) - await self._delete_reminder(reminder["id"]) + await channel.send(content=f"{user.mention} {additional_mentions}", embed=embed) + + log.debug(f"Deleting reminder #{reminder['id']} (the user has been reminded).") + await self.bot.api_client.delete(f"bot/reminders/{reminder['id']}") @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) async def remind_group( @@ -395,6 +378,7 @@ class Reminders(Cog): mention_ids = [mention.id for mention in mentions] await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) + @lock_arg(NAMESPACE, "id_", raise_error=True) async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None: """Edits a reminder with the given payload, then sends a confirmation message.""" if not await self._can_modify(ctx, id_): @@ -414,11 +398,15 @@ class Reminders(Cog): await self._reschedule_reminder(reminder) @remind_group.command("delete", aliases=("remove", "cancel")) + @lock_arg(NAMESPACE, "id_", raise_error=True) async def delete_reminder(self, ctx: Context, id_: int) -> None: """Delete one of your active reminders.""" if not await self._can_modify(ctx, id_): return - await self._delete_reminder(id_) + + await self.bot.api_client.delete(f"bot/reminders/{id_}") + self.scheduler.cancel(id_) + await self._send_confirmation( ctx, on_success="That reminder has been deleted successfully!", diff --git a/bot/utils/function.py b/bot/utils/function.py new file mode 100644 index 000000000..3ab32fe3c --- /dev/null +++ b/bot/utils/function.py @@ -0,0 +1,75 @@ +"""Utilities for interaction with functions.""" + +import inspect +import typing as t + +Argument = t.Union[int, str] +BoundArgs = t.OrderedDict[str, t.Any] +Decorator = t.Callable[[t.Callable], t.Callable] +ArgValGetter = t.Callable[[BoundArgs], t.Any] + + +def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any: + """ + Return a value from `arguments` based on a name or position. + + `arguments` is an ordered mapping of parameter names to argument values. + + Raise TypeError if `name_or_pos` isn't a str or int. + Raise ValueError if `name_or_pos` does not match any argument. + """ + if isinstance(name_or_pos, int): + # Convert arguments to a tuple to make them indexable. + arg_values = tuple(arguments.items()) + arg_pos = name_or_pos + + try: + name, value = arg_values[arg_pos] + return value + except IndexError: + raise ValueError(f"Argument position {arg_pos} is out of bounds.") + elif isinstance(name_or_pos, str): + arg_name = name_or_pos + try: + return arguments[arg_name] + except KeyError: + raise ValueError(f"Argument {arg_name!r} doesn't exist.") + else: + raise TypeError("'arg' must either be an int (positional index) or a str (keyword).") + + +def get_arg_value_wrapper( + decorator_func: t.Callable[[ArgValGetter], Decorator], + name_or_pos: Argument, + func: t.Callable[[t.Any], t.Any] = None, +) -> Decorator: + """ + Call `decorator_func` with the value of the arg at the given name/position. + + `decorator_func` must accept a callable as a parameter to which it will pass a mapping of + parameter names to argument values of the function it's decorating. + + `func` is an optional callable which will return a new value given the argument's value. + + Return the decorator returned by `decorator_func`. + """ + def wrapper(args: BoundArgs) -> t.Any: + value = get_arg_value(name_or_pos, args) + if func: + value = func(value) + return value + + return decorator_func(wrapper) + + +def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any]) -> BoundArgs: + """ + Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values. + + Default parameter values are also set. + """ + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + return bound_args.arguments diff --git a/bot/utils/lock.py b/bot/utils/lock.py new file mode 100644 index 000000000..7aaafbc88 --- /dev/null +++ b/bot/utils/lock.py @@ -0,0 +1,114 @@ +import inspect +import logging +from collections import defaultdict +from functools import partial, wraps +from typing import Any, Awaitable, Callable, Hashable, Union +from weakref import WeakValueDictionary + +from bot.errors import LockedResourceError +from bot.utils import function + +log = logging.getLogger(__name__) +__lock_dicts = defaultdict(WeakValueDictionary) + +_IdCallableReturn = Union[Hashable, Awaitable[Hashable]] +_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn] +ResourceId = Union[Hashable, _IdCallable] + + +class LockGuard: + """ + A context manager which acquires and releases a lock (mutex). + + Raise RuntimeError if trying to acquire a locked lock. + """ + + def __init__(self): + self._locked = False + + @property + def locked(self) -> bool: + """Return True if currently locked or False if unlocked.""" + return self._locked + + def __enter__(self): + if self._locked: + raise RuntimeError("Cannot acquire a locked lock.") + + self._locked = True + + def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001 + self._locked = False + return False # Indicate any raised exception shouldn't be suppressed. + + +def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable: + """ + Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`. + + If any other mutually exclusive function currently holds the lock for a resource, do not run the + decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if + the lock cannot be acquired. + + `namespace` is an identifier used to prevent collisions among resource IDs. + + `resource_id` identifies a resource on which to perform a mutually exclusive operation. + It may also be a callable or awaitable which will return the resource ID given an ordered + mapping of the parameters' names to arguments' values. + + If decorating a command, this decorator must go before (below) the `command` decorator. + """ + def decorator(func: Callable) -> Callable: + name = func.__name__ + + @wraps(func) + async def wrapper(*args, **kwargs) -> Any: + log.trace(f"{name}: mutually exclusive decorator called") + + if callable(resource_id): + log.trace(f"{name}: binding args to signature") + bound_args = function.get_bound_args(func, args, kwargs) + + log.trace(f"{name}: calling the given callable to get the resource ID") + id_ = resource_id(bound_args) + + if inspect.isawaitable(id_): + log.trace(f"{name}: awaiting to get resource ID") + id_ = await id_ + else: + id_ = resource_id + + log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}") + + # Get the lock for the ID. Create a lock if one doesn't exist yet. + locks = __lock_dicts[namespace] + lock_guard = locks.setdefault(id_, LockGuard()) + + if not lock_guard.locked: + log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...") + with lock_guard: + return await func(*args, **kwargs) + else: + log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked") + if raise_error: + raise LockedResourceError(str(namespace), id_) + + return wrapper + return decorator + + +def lock_arg( + namespace: Hashable, + name_or_pos: function.Argument, + func: Callable[[Any], _IdCallableReturn] = None, + *, + raise_error: bool = False, +) -> Callable: + """ + Apply the `lock` decorator using the value of the arg at the given name/position as the ID. + + `func` is an optional callable or awaitable which will return the ID given the argument value. + See `lock` docs for more information. + """ + decorator_func = partial(lock, namespace, raise_error=raise_error) + return function.get_arg_value_wrapper(decorator_func, name_or_pos, func) |