diff options
| -rw-r--r-- | bot/decorators.py | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/bot/decorators.py b/bot/decorators.py index d9e5e3a83..1fe082b6e 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -14,6 +14,8 @@ from bot.utils.checks import in_whitelist_check, with_role_check, without_role_c log = logging.getLogger(__name__) +Argument = t.Union[int, str] + def in_whitelist( *, @@ -138,7 +140,7 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N return wrap -def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable: +def respect_role_hierarchy(target_arg: Argument = 0) -> t.Callable: """ Ensure the highest role of the invoking member is greater than that of the target member. @@ -153,15 +155,7 @@ def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable: def wrap(func: t.Callable) -> t.Callable: @wraps(func) async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None: - try: - target = kwargs[target_arg] - except KeyError: - try: - target = args[target_arg] - except IndexError: - raise ValueError(f"Could not find target argument at position {target_arg}") - except TypeError: - raise ValueError(f"Could not find target kwarg with key {target_arg!r}") + target = _get_arg_value(target_arg, args, kwargs) if not isinstance(target, Member): log.trace("The target is not a discord.Member; skipping role hierarchy check.") @@ -183,3 +177,23 @@ def respect_role_hierarchy(target_arg: t.Union[int, str] = 0) -> t.Callable: await func(self, ctx, *args, **kwargs) return inner return wrap + + +def _get_arg_value(target_arg: Argument, args: t.Tuple, kwargs: t.Dict[str, t.Any]) -> t.Any: + """ + Return the value of the arg at the given position or name `target_arg`. + + Use an integer as a position if the target argument is positional. + Use a string as a parameter name if the target argument is a keyword argument. + + Raise ValueError if `target_arg` cannot be found. + """ + try: + return kwargs[target_arg] + except KeyError: + try: + return args[target_arg] + except IndexError: + raise ValueError(f"Could not find target argument at position {target_arg}") + except TypeError: + raise ValueError(f"Could not find target kwarg with key {target_arg!r}") |