aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils/decorators.py
diff options
context:
space:
mode:
authorGravatar Shivansh-007 <[email protected]>2021-04-03 11:19:02 +0530
committerGravatar Shivansh-007 <[email protected]>2021-04-03 11:19:02 +0530
commit1fea8359f72d2ae03dfbd970ff1430ff787a129b (patch)
tree3efdd0b92048164711232c67dbd948e7d287ed86 /bot/utils/decorators.py
parentUse constants for delete delay and remove redundant f-string. (diff)
parentMerge branch 'main' into feature/command-suggestions (diff)
Merge remote-tracking branch 'origin/feature/command-suggestions' into feature/command-suggestions
Diffstat (limited to 'bot/utils/decorators.py')
-rw-r--r--bot/utils/decorators.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index c12a15ff..60066dc4 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -11,7 +11,7 @@ from discord import Colour, Embed
from discord.ext import commands
from discord.ext.commands import CheckFailure, Command, Context
-from bot.constants import ERROR_REPLIES, Month
+from bot.constants import Channels, ERROR_REPLIES, Month, WHITELISTED_CHANNELS
from bot.utils import human_months, resolve_current_month
from bot.utils.checks import in_whitelist_check
@@ -253,6 +253,12 @@ def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context],
channels = set(kwargs.get("channels") or {})
categories = kwargs.get("categories")
+ # Only output override channels + community_bot_commands
+ if channels:
+ default_whitelist_channels = set(WHITELISTED_CHANNELS)
+ default_whitelist_channels.discard(Channels.community_bot_commands)
+ channels.difference_update(default_whitelist_channels)
+
# Add all whitelisted category channels
if categories:
for category_id in categories:
@@ -260,7 +266,7 @@ def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context],
if category is None:
continue
- [channels.add(channel.id) for channel in category.text_channels]
+ channels.update(channel.id for channel in category.text_channels)
if channels:
channels_str = ', '.join(f"<#{c_id}>" for c_id in channels)