diff options
author | 2021-12-15 20:20:48 -0500 | |
---|---|---|
committer | 2022-01-12 09:26:25 -0500 | |
commit | 2ee5d245406997e171d6694cd0f4de5d49423605 (patch) | |
tree | 7fc64de296051c9f042a150db81eca6747df648d /bot/utils | |
parent | Merge pull request #1008 from python-discord/fix-aoc-join-logic (diff) |
fix: subcommands inherit their parent's whitelist
solves issue with adding decorator to the parent
which wouldn't apply to the children
Diffstat (limited to 'bot/utils')
-rw-r--r-- | bot/utils/decorators.py | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 7a3b14ad..3688327a 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -199,13 +199,29 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo kwargs = default_kwargs.copy() allow_dms = False - # Update kwargs based on override - if hasattr(ctx.command.callback, "override"): + # determine which command's overrides we will use + # as we have groups, we want to ensure that group commands inherit from the parent + overridden_command: Union[commands.Command, commands.Group] = None + for command in [ctx.command, *ctx.command.parents]: + print(command) + if hasattr(command.callback, "override"): + overridden_command = command + break + if overridden_command is not None: + log.debug(f'Command {overridden_command} has overrides') + if overridden_command is not ctx.command: + log.debug( + f"Command '{ctx.command.qualified_name}' inherited overrides " + "from parent command '{overridden_command.qualified_name}'" + ) + + # Update kwargs based on override, if one exists + if overridden_command and hasattr(overridden_command.callback, "override"): # Handle DM invocations - allow_dms = ctx.command.callback.override_dm + allow_dms = overridden_command.callback.override_dm # Remove default kwargs if reset is True - if ctx.command.callback.override_reset: + if overridden_command.callback.override_reset: kwargs = {} log.debug( f"{ctx.author} called the '{ctx.command.name}' command and " @@ -213,9 +229,9 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo ) # Merge overwrites and defaults - for arg in ctx.command.callback.override: + for arg in overridden_command.callback.override: default_value = kwargs.get(arg) - new_value = ctx.command.callback.override[arg] + new_value = overridden_command.callback.override[arg] # Skip values that don't need merging, or can't be merged if default_value is None or isinstance(arg, int): |