aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/decorators.py99
-rw-r--r--bot/errors.py20
-rw-r--r--bot/exts/backend/error_handler.py3
-rw-r--r--bot/exts/moderation/infraction/infractions.py4
-rw-r--r--bot/exts/utils/reminders.py50
-rw-r--r--bot/utils/function.py75
-rw-r--r--bot/utils/lock.py114
7 files changed, 265 insertions, 100 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index 2518124da..063c8f878 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,16 +1,15 @@
+import asyncio
import logging
-import random
-from asyncio import Lock, create_task, sleep
+import typing as t
from contextlib import suppress
from functools import wraps
-from typing import Callable, Container, Optional, Union
-from weakref import WeakValueDictionary
-from discord import Colour, Embed, Member, NotFound
+from discord import Member, NotFound
from discord.ext import commands
from discord.ext.commands import Cog, Context
-from bot.constants import Channels, ERROR_REPLIES, RedirectOutput
+from bot.constants import Channels, RedirectOutput
+from bot.utils import function
from bot.utils.checks import in_whitelist_check
log = logging.getLogger(__name__)
@@ -18,12 +17,12 @@ log = logging.getLogger(__name__)
def in_whitelist(
*,
- channels: Container[int] = (),
- categories: Container[int] = (),
- roles: Container[int] = (),
- redirect: Optional[int] = Channels.bot_commands,
+ channels: t.Container[int] = (),
+ categories: t.Container[int] = (),
+ roles: t.Container[int] = (),
+ redirect: t.Optional[int] = Channels.bot_commands,
fail_silently: bool = False,
-) -> Callable:
+) -> t.Callable:
"""
Check if a command was issued in a whitelisted context.
@@ -31,7 +30,7 @@ def in_whitelist(
- `channels`: a container with channel ids for whitelisted channels
- `categories`: a container with category ids for whitelisted categories
- - `roles`: a container with with role ids for whitelisted roles
+ - `roles`: a container with role ids for whitelisted roles
If the command was invoked in a context that was not whitelisted, the member is either
redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
@@ -44,7 +43,7 @@ def in_whitelist(
return commands.check(predicate)
-def has_no_roles(*roles: Union[str, int]) -> Callable:
+def has_no_roles(*roles: t.Union[str, int]) -> t.Callable:
"""
Returns True if the user does not have any of the roles specified.
@@ -63,39 +62,7 @@ def has_no_roles(*roles: Union[str, int]) -> Callable:
return commands.check(predicate)
-def locked() -> Callable:
- """
- Allows the user to only run one instance of the decorated command at a time.
-
- Subsequent calls to the command from the same author are ignored until the command has completed invocation.
-
- This decorator must go before (below) the `command` decorator.
- """
- def wrap(func: Callable) -> Callable:
- func.__locks = WeakValueDictionary()
-
- @wraps(func)
- async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- lock = func.__locks.setdefault(ctx.author.id, Lock())
- if lock.locked():
- embed = Embed()
- embed.colour = Colour.red()
-
- log.debug("User tried to invoke a locked command.")
- embed.description = (
- "You're already using this command. Please wait until it is done before you use it again."
- )
- embed.title = random.choice(ERROR_REPLIES)
- await ctx.send(embed=embed)
- return
-
- async with func.__locks.setdefault(ctx.author.id, Lock()):
- await func(self, ctx, *args, **kwargs)
- return inner
- return wrap
-
-
-def redirect_output(destination_channel: int, bypass_roles: Container[int] = None) -> Callable:
+def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = None) -> t.Callable:
"""
Changes the channel in the context of the command to redirect the output to a certain channel.
@@ -103,7 +70,7 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def wrap(func: t.Callable) -> t.Callable:
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
if ctx.channel.id == destination_channel:
@@ -122,14 +89,14 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
log.trace(f"Redirecting output of {ctx.author}'s command '{ctx.command.name}' to {redirect_channel.name}")
ctx.channel = redirect_channel
await ctx.channel.send(f"Here's the output of your command, {ctx.author.mention}")
- create_task(func(self, ctx, *args, **kwargs))
+ asyncio.create_task(func(self, ctx, *args, **kwargs))
message = await old_channel.send(
f"Hey, {ctx.author.mention}, you can find the output of your command here: "
f"{redirect_channel.mention}"
)
if RedirectOutput.delete_invocation:
- await sleep(RedirectOutput.delete_delay)
+ await asyncio.sleep(RedirectOutput.delete_delay)
with suppress(NotFound):
await message.delete()
@@ -143,38 +110,35 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
return wrap
-def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
+def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable:
"""
Ensure the highest role of the invoking member is greater than that of the target member.
If the condition fails, a warning is sent to the invoking context. A target which is not an
instance of discord.Member will always pass.
- A value of 0 (i.e. position 0) for `target_arg` corresponds to the argument which comes after
- `ctx`. If the target argument is a kwarg, its name can instead be given.
+ `member_arg` is the keyword name or position index of the parameter of the decorated command
+ whose value is the target member.
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: Callable) -> Callable:
+ def decorator(func: t.Callable) -> t.Callable:
@wraps(func)
- async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- try:
- target = kwargs[target_arg]
- except KeyError:
- try:
- target = args[target_arg]
- except IndexError:
- raise ValueError(f"Could not find target argument at position {target_arg}")
- except TypeError:
- raise ValueError(f"Could not find target kwarg with key {target_arg!r}")
+ async def wrapper(*args, **kwargs) -> None:
+ log.trace(f"{func.__name__}: respect role hierarchy decorator called")
+
+ bound_args = function.get_bound_args(func, args, kwargs)
+ target = function.get_arg_value(member_arg, bound_args)
if not isinstance(target, Member):
log.trace("The target is not a discord.Member; skipping role hierarchy check.")
- await func(self, ctx, *args, **kwargs)
+ await func(*args, **kwargs)
return
+ ctx = function.get_arg_value(1, bound_args)
cmd = ctx.command.name
actor = ctx.author
+
if target.top_role >= actor.top_role:
log.info(
f"{actor} ({actor.id}) attempted to {cmd} "
@@ -185,6 +149,7 @@ def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
"someone with an equal or higher top role."
)
else:
- await func(self, ctx, *args, **kwargs)
- return inner
- return wrap
+ log.trace(f"{func.__name__}: {target.top_role=} < {actor.top_role=}; calling func")
+ await func(*args, **kwargs)
+ return wrapper
+ return decorator
diff --git a/bot/errors.py b/bot/errors.py
new file mode 100644
index 000000000..65d715203
--- /dev/null
+++ b/bot/errors.py
@@ -0,0 +1,20 @@
+from typing import Hashable
+
+
+class LockedResourceError(RuntimeError):
+ """
+ Exception raised when an operation is attempted on a locked resource.
+
+ Attributes:
+ `type` -- name of the locked resource's type
+ `id` -- ID of the locked resource
+ """
+
+ def __init__(self, resource_type: str, resource_id: Hashable):
+ self.type = resource_type
+ self.id = resource_id
+
+ super().__init__(
+ f"Cannot operate on {self.type.lower()} `{self.id}`; "
+ "it is currently locked and in use by another operation."
+ )
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index f9d4de638..c643d346e 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -10,6 +10,7 @@ from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels, Colours
from bot.converters import TagNameConverter
+from bot.errors import LockedResourceError
from bot.utils.checks import InWhitelistCheckFailure
log = logging.getLogger(__name__)
@@ -75,6 +76,8 @@ class ErrorHandler(Cog):
elif isinstance(e, errors.CommandInvokeError):
if isinstance(e.original, ResponseCodeError):
await self.handle_api_error(ctx, e.original)
+ elif isinstance(e.original, LockedResourceError):
+ await ctx.send(f"{e.original} Please wait for it to finish and try again later.")
else:
await self.handle_unexpected_error(ctx, e.original)
return # Exit early to avoid logging.
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index ef6f6e3c6..a8b3feb38 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -230,7 +230,7 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_infraction(ctx, infraction, user, action())
- @respect_role_hierarchy()
+ @respect_role_hierarchy(member_arg=2)
async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None:
"""Apply a kick infraction with kwargs passed to `post_infraction`."""
infraction = await _utils.post_infraction(ctx, user, "kick", reason, active=False, **kwargs)
@@ -245,7 +245,7 @@ class Infractions(InfractionScheduler, commands.Cog):
action = user.kick(reason=reason)
await self.apply_infraction(ctx, infraction, user, action)
- @respect_role_hierarchy()
+ @respect_role_hierarchy(member_arg=2)
async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None:
"""
Apply a ban infraction with kwargs passed to `post_infraction`.
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index b8c15853f..bf4e24661 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -16,12 +16,14 @@ from bot.constants import Guild, Icons, MODERATION_ROLES, POSITIVE_REPLIES, Role
from bot.converters import Duration
from bot.pagination import LinePaginator
from bot.utils.checks import has_any_role_check, has_no_roles_check
+from bot.utils.lock import lock_arg
from bot.utils.messages import send_denial
from bot.utils.scheduling import Scheduler
from bot.utils.time import humanize_delta
log = logging.getLogger(__name__)
+NAMESPACE = "reminder" # Used for the mutually_exclusive decorator; constant to prevent typos
WHITELISTED_CHANNELS = Guild.reminder_whitelist
MAXIMUM_REMINDERS = 5
@@ -52,7 +54,7 @@ class Reminders(Cog):
now = datetime.utcnow()
for reminder in response:
- is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False)
+ is_valid, *_ = self.ensure_valid_reminder(reminder)
if not is_valid:
continue
@@ -65,11 +67,7 @@ class Reminders(Cog):
else:
self.schedule_reminder(reminder)
- def ensure_valid_reminder(
- self,
- reminder: dict,
- cancel_task: bool = True
- ) -> t.Tuple[bool, discord.User, discord.TextChannel]:
+ def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.User, discord.TextChannel]:
"""Ensure reminder author and channel can be fetched otherwise delete the reminder."""
user = self.bot.get_user(reminder['author'])
channel = self.bot.get_channel(reminder['channel_id'])
@@ -80,7 +78,7 @@ class Reminders(Cog):
f"Reminder {reminder['id']} invalid: "
f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}."
)
- asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task))
+ asyncio.create_task(self.bot.api_client.delete(f"bot/reminders/{reminder['id']}"))
return is_valid, user, channel
@@ -88,7 +86,7 @@ class Reminders(Cog):
async def _send_confirmation(
ctx: Context,
on_success: str,
- reminder_id: str,
+ reminder_id: t.Union[str, int],
delivery_dt: t.Optional[datetime],
) -> None:
"""Send an embed confirming the reminder change was made successfully."""
@@ -148,24 +146,8 @@ class Reminders(Cog):
def schedule_reminder(self, reminder: dict) -> None:
"""A coroutine which sends the reminder once the time is reached, and cancels the running task."""
- reminder_id = reminder["id"]
reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None)
-
- async def _remind() -> None:
- await self.send_reminder(reminder)
-
- log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).")
- await self._delete_reminder(reminder_id)
-
- self.scheduler.schedule_at(reminder_datetime, reminder_id, _remind())
-
- async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None:
- """Delete a reminder from the database, given its ID, and cancel the running task."""
- await self.bot.api_client.delete('bot/reminders/' + str(reminder_id))
-
- if cancel_task:
- # Now we can remove it from the schedule list
- self.scheduler.cancel(reminder_id)
+ self.scheduler.schedule_at(reminder_datetime, reminder["id"], self.send_reminder(reminder))
async def _edit_reminder(self, reminder_id: int, payload: dict) -> dict:
"""
@@ -188,10 +170,12 @@ class Reminders(Cog):
log.trace(f"Scheduling new task #{reminder['id']}")
self.schedule_reminder(reminder)
+ @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True)
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
is_valid, user, channel = self.ensure_valid_reminder(reminder)
if not is_valid:
+ # No need to cancel the task too; it'll simply be done once this coroutine returns.
return
embed = discord.Embed()
@@ -217,11 +201,10 @@ class Reminders(Cog):
mentionable.mention for mentionable in self.get_mentionables(reminder["mentions"])
)
- await channel.send(
- content=f"{user.mention} {additional_mentions}",
- embed=embed
- )
- await self._delete_reminder(reminder["id"])
+ await channel.send(content=f"{user.mention} {additional_mentions}", embed=embed)
+
+ log.debug(f"Deleting reminder #{reminder['id']} (the user has been reminded).")
+ await self.bot.api_client.delete(f"bot/reminders/{reminder['id']}")
@group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True)
async def remind_group(
@@ -395,6 +378,7 @@ class Reminders(Cog):
mention_ids = [mention.id for mention in mentions]
await self.edit_reminder(ctx, id_, {"mentions": mention_ids})
+ @lock_arg(NAMESPACE, "id_", raise_error=True)
async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None:
"""Edits a reminder with the given payload, then sends a confirmation message."""
if not await self._can_modify(ctx, id_):
@@ -414,11 +398,15 @@ class Reminders(Cog):
await self._reschedule_reminder(reminder)
@remind_group.command("delete", aliases=("remove", "cancel"))
+ @lock_arg(NAMESPACE, "id_", raise_error=True)
async def delete_reminder(self, ctx: Context, id_: int) -> None:
"""Delete one of your active reminders."""
if not await self._can_modify(ctx, id_):
return
- await self._delete_reminder(id_)
+
+ await self.bot.api_client.delete(f"bot/reminders/{id_}")
+ self.scheduler.cancel(id_)
+
await self._send_confirmation(
ctx,
on_success="That reminder has been deleted successfully!",
diff --git a/bot/utils/function.py b/bot/utils/function.py
new file mode 100644
index 000000000..3ab32fe3c
--- /dev/null
+++ b/bot/utils/function.py
@@ -0,0 +1,75 @@
+"""Utilities for interaction with functions."""
+
+import inspect
+import typing as t
+
+Argument = t.Union[int, str]
+BoundArgs = t.OrderedDict[str, t.Any]
+Decorator = t.Callable[[t.Callable], t.Callable]
+ArgValGetter = t.Callable[[BoundArgs], t.Any]
+
+
+def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any:
+ """
+ Return a value from `arguments` based on a name or position.
+
+ `arguments` is an ordered mapping of parameter names to argument values.
+
+ Raise TypeError if `name_or_pos` isn't a str or int.
+ Raise ValueError if `name_or_pos` does not match any argument.
+ """
+ if isinstance(name_or_pos, int):
+ # Convert arguments to a tuple to make them indexable.
+ arg_values = tuple(arguments.items())
+ arg_pos = name_or_pos
+
+ try:
+ name, value = arg_values[arg_pos]
+ return value
+ except IndexError:
+ raise ValueError(f"Argument position {arg_pos} is out of bounds.")
+ elif isinstance(name_or_pos, str):
+ arg_name = name_or_pos
+ try:
+ return arguments[arg_name]
+ except KeyError:
+ raise ValueError(f"Argument {arg_name!r} doesn't exist.")
+ else:
+ raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")
+
+
+def get_arg_value_wrapper(
+ decorator_func: t.Callable[[ArgValGetter], Decorator],
+ name_or_pos: Argument,
+ func: t.Callable[[t.Any], t.Any] = None,
+) -> Decorator:
+ """
+ Call `decorator_func` with the value of the arg at the given name/position.
+
+ `decorator_func` must accept a callable as a parameter to which it will pass a mapping of
+ parameter names to argument values of the function it's decorating.
+
+ `func` is an optional callable which will return a new value given the argument's value.
+
+ Return the decorator returned by `decorator_func`.
+ """
+ def wrapper(args: BoundArgs) -> t.Any:
+ value = get_arg_value(name_or_pos, args)
+ if func:
+ value = func(value)
+ return value
+
+ return decorator_func(wrapper)
+
+
+def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any]) -> BoundArgs:
+ """
+ Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.
+
+ Default parameter values are also set.
+ """
+ sig = inspect.signature(func)
+ bound_args = sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+
+ return bound_args.arguments
diff --git a/bot/utils/lock.py b/bot/utils/lock.py
new file mode 100644
index 000000000..7aaafbc88
--- /dev/null
+++ b/bot/utils/lock.py
@@ -0,0 +1,114 @@
+import inspect
+import logging
+from collections import defaultdict
+from functools import partial, wraps
+from typing import Any, Awaitable, Callable, Hashable, Union
+from weakref import WeakValueDictionary
+
+from bot.errors import LockedResourceError
+from bot.utils import function
+
+log = logging.getLogger(__name__)
+__lock_dicts = defaultdict(WeakValueDictionary)
+
+_IdCallableReturn = Union[Hashable, Awaitable[Hashable]]
+_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
+ResourceId = Union[Hashable, _IdCallable]
+
+
+class LockGuard:
+ """
+ A context manager which acquires and releases a lock (mutex).
+
+ Raise RuntimeError if trying to acquire a locked lock.
+ """
+
+ def __init__(self):
+ self._locked = False
+
+ @property
+ def locked(self) -> bool:
+ """Return True if currently locked or False if unlocked."""
+ return self._locked
+
+ def __enter__(self):
+ if self._locked:
+ raise RuntimeError("Cannot acquire a locked lock.")
+
+ self._locked = True
+
+ def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: ANN001
+ self._locked = False
+ return False # Indicate any raised exception shouldn't be suppressed.
+
+
+def lock(namespace: Hashable, resource_id: ResourceId, *, raise_error: bool = False) -> Callable:
+ """
+ Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
+
+ If any other mutually exclusive function currently holds the lock for a resource, do not run the
+ decorated function and return None. If `raise_error` is True, raise `LockedResourceError` if
+ the lock cannot be acquired.
+
+ `namespace` is an identifier used to prevent collisions among resource IDs.
+
+ `resource_id` identifies a resource on which to perform a mutually exclusive operation.
+ It may also be a callable or awaitable which will return the resource ID given an ordered
+ mapping of the parameters' names to arguments' values.
+
+ If decorating a command, this decorator must go before (below) the `command` decorator.
+ """
+ def decorator(func: Callable) -> Callable:
+ name = func.__name__
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs) -> Any:
+ log.trace(f"{name}: mutually exclusive decorator called")
+
+ if callable(resource_id):
+ log.trace(f"{name}: binding args to signature")
+ bound_args = function.get_bound_args(func, args, kwargs)
+
+ log.trace(f"{name}: calling the given callable to get the resource ID")
+ id_ = resource_id(bound_args)
+
+ if inspect.isawaitable(id_):
+ log.trace(f"{name}: awaiting to get resource ID")
+ id_ = await id_
+ else:
+ id_ = resource_id
+
+ log.trace(f"{name}: getting lock for resource {id_!r} under namespace {namespace!r}")
+
+ # Get the lock for the ID. Create a lock if one doesn't exist yet.
+ locks = __lock_dicts[namespace]
+ lock_guard = locks.setdefault(id_, LockGuard())
+
+ if not lock_guard.locked:
+ log.debug(f"{name}: resource {namespace!r}:{id_!r} is free; acquiring it...")
+ with lock_guard:
+ return await func(*args, **kwargs)
+ else:
+ log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
+ if raise_error:
+ raise LockedResourceError(str(namespace), id_)
+
+ return wrapper
+ return decorator
+
+
+def lock_arg(
+ namespace: Hashable,
+ name_or_pos: function.Argument,
+ func: Callable[[Any], _IdCallableReturn] = None,
+ *,
+ raise_error: bool = False,
+) -> Callable:
+ """
+ Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
+
+ `func` is an optional callable or awaitable which will return the ID given the argument value.
+ See `lock` docs for more information.
+ """
+ decorator_func = partial(lock, namespace, raise_error=raise_error)
+ return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)