aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-07-17 10:13:52 -0700
committerGravatar MarkKoz <[email protected]>2020-07-31 22:58:05 -0700
commit5bc3c1e9732955020d8a5f92a8c1952bca3dae0c (patch)
tree4f31292f95bd7405c41c33927bd2c2527824dc80
parentDecorators: clean up imports (diff)
Decorators: create helper function to get arg value
-rw-r--r--bot/decorators.py34
1 files changed, 24 insertions, 10 deletions
diff --git a/bot/decorators.py b/bot/decorators.py
index d9e5e3a83..1fe082b6e 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -14,6 +14,8 @@ from bot.utils.checks import in_whitelist_check, with_role_check, without_role_c
log = logging.getLogger(__name__)
+Argument = t.Union[int, str]
+
def in_whitelist(
*,
@@ -138,7 +140,7 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N
return wrap
-def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
+def respect_role_hierarchy(target_arg: Argument = 0) -> t.Callable:
"""
Ensure the highest role of the invoking member is greater than that of the target member.
@@ -153,15 +155,7 @@ def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
def wrap(func: t.Callable) -> t.Callable:
@wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
- try:
- target = kwargs[target_arg]
- except KeyError:
- try:
- target = args[target_arg]
- except IndexError:
- raise ValueError(f"Could not find target argument at position {target_arg}")
- except TypeError:
- raise ValueError(f"Could not find target kwarg with key {target_arg!r}")
+ target = _get_arg_value(target_arg, args, kwargs)
if not isinstance(target, Member):
log.trace("The target is not a discord.Member; skipping role hierarchy check.")
@@ -183,3 +177,23 @@ def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable:
await func(self, ctx, *args, **kwargs)
return inner
return wrap
+
+
+def _get_arg_value(target_arg: Argument, args: t.Tuple, kwargs: t.Dict[str, t.Any]) -> t.Any:
+ """
+ Return the value of the arg at the given position or name `target_arg`.
+
+ Use an integer as a position if the target argument is positional.
+ Use a string as a parameter name if the target argument is a keyword argument.
+
+ Raise ValueError if `target_arg` cannot be found.
+ """
+ try:
+ return kwargs[target_arg]
+ except KeyError:
+ try:
+ return args[target_arg]
+ except IndexError:
+ raise ValueError(f"Could not find target argument at position {target_arg}")
+ except TypeError:
+ raise ValueError(f"Could not find target kwarg with key {target_arg!r}")