aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/decorators.py34
1 files changed, 24 insertions, 10 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index d9e5e3a83..1fe082b6e 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -14,6 +14,8 @@ from bot.utils.checks import in_whitelist_check, with_role_check, without_role_c
log = logging.getLogger(__name__)
+Argument = t.Union[int, str]
+
def in_whitelist(
*,
@@ -138,7 +140,7 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N
return wrap
-def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
+def respect_role_hierarchy(target_arg: Argument = 0) -> t.Callable:
"""
Ensure the highest role of the invoking member is greater than that of the target member.
@@ -153,15 +155,7 @@ def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
def wrap(func: t.Callable) -> t.Callable:
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- try:
- target = kwargs[target_arg]
- except KeyError:
- try:
- target = args[target_arg]
- except IndexError:
- raise ValueError(f"Could not find target argument at position {target_arg}")
- except TypeError:
- raise ValueError(f"Could not find target kwarg with key {target_arg!r}")
+ target = _get_arg_value(target_arg, args, kwargs)
if not isinstance(target, Member):
log.trace("The target is not a discord.Member; skipping role hierarchy check.")
@@ -183,3 +177,23 @@ def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
await func(self, ctx, *args, **kwargs)
return inner
return wrap
+
+
+def _get_arg_value(target_arg: Argument, args: t.Tuple, kwargs: t.Dict[str, t.Any]) -> t.Any:
+ """
+ Return the value of the arg at the given position or name `target_arg`.
+
+ Use an integer as a position if the target argument is positional.
+ Use a string as a parameter name if the target argument is a keyword argument.
+
+ Raise ValueError if `target_arg` cannot be found.
+ """
+ try:
+ return kwargs[target_arg]
+ except KeyError:
+ try:
+ return args[target_arg]
+ except IndexError:
+ raise ValueError(f"Could not find target argument at position {target_arg}")
+ except TypeError:
+ raise ValueError(f"Could not find target kwarg with key {target_arg!r}")