aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar onerandomusername <[email protected]>2021-12-15 20:20:48 -0500
committerGravatar onerandomusername <[email protected]>2022-01-12 09:26:25 -0500
commit2ee5d245406997e171d6694cd0f4de5d49423605 (patch)
tree7fc64de296051c9f042a150db81eca6747df648d
parentMerge pull request #1008 from python-discord/fix-aoc-join-logic (diff)
fix: subcommands inherit their parent's whitelist
solves issue with adding decorator to the parent which wouldn't apply to the children
-rw-r--r--bot/utils/decorators.py28
1 files changed, 22 insertions, 6 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 7a3b14ad..3688327a 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -199,13 +199,29 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
kwargs = default_kwargs.copy()
allow_dms = False
- # Update kwargs based on override
- if hasattr(ctx.command.callback, "override"):
+ # determine which command's overrides we will use
+ # as we have groups, we want to ensure that group commands inherit from the parent
+ overridden_command: Union[commands.Command, commands.Group] = None
+ for command in [ctx.command, *ctx.command.parents]:
+ print(command)
+ if hasattr(command.callback, "override"):
+ overridden_command = command
+ break
+ if overridden_command is not None:
+ log.debug(f'Command {overridden_command} has overrides')
+ if overridden_command is not ctx.command:
+ log.debug(
+ f"Command '{ctx.command.qualified_name}' inherited overrides "
+ "from parent command '{overridden_command.qualified_name}'"
+ )
+
+ # Update kwargs based on override, if one exists
+ if overridden_command and hasattr(overridden_command.callback, "override"):
# Handle DM invocations
- allow_dms = ctx.command.callback.override_dm
+ allow_dms = overridden_command.callback.override_dm
# Remove default kwargs if reset is True
- if ctx.command.callback.override_reset:
+ if overridden_command.callback.override_reset:
kwargs = {}
log.debug(
f"{ctx.author} called the '{ctx.command.name}' command and "
@@ -213,9 +229,9 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
)
# Merge overwrites and defaults
- for arg in ctx.command.callback.override:
+ for arg in overridden_command.callback.override:
default_value = kwargs.get(arg)
- new_value = ctx.command.callback.override[arg]
+ new_value = overridden_command.callback.override[arg]
# Skip values that don't need merging, or can't be merged
if default_value is None or isinstance(arg, int):