aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar S. Co1 <[email protected]>2019-10-25 09:19:39 -0400
committerGravatar GitHub <[email protected]>2019-10-25 09:19:39 -0400
commit586aebb3593b448b65efdfe7a8759e834b6bf407 (patch)
tree190628d6b35972073a014e8ab4264e1402e7991e
parentMerge pull request #550 from python-discord/###-filtering-devtest (diff)
parentMerge branch 'master' into moderator-channel-check (diff)
Merge pull request #543 from atmishra/moderator-channel-check
Restrict ModManagement commands to moderation channels
-rw-r--r--bot/cogs/moderation/management.py10
-rw-r--r--bot/constants.py3
-rw-r--r--bot/utils/checks.py6
-rw-r--r--config-default.yml2
-rw-r--r--tests/bot/utils/test_checks.py8
5 files changed, 22 insertions, 7 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index 491f6d400..44a508436 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -11,7 +11,7 @@ from bot import constants
from bot.converters import InfractionSearchQuery
from bot.pagination import LinePaginator
from bot.utils import time
-from bot.utils.checks import with_role_check
+from bot.utils.checks import in_channel_check, with_role_check
from . import utils
from .infractions import Infractions
from .modlog import ModLog
@@ -256,8 +256,12 @@ class ModManagement(commands.Cog):
# This cannot be static (must have a __func__ attribute).
def cog_check(self, ctx: Context) -> bool:
- """Only allow moderators to invoke the commands in this cog."""
- return with_role_check(ctx, *constants.MODERATION_ROLES)
+ """Only allow moderators from moderator channels to invoke the commands in this cog."""
+ checks = [
+ with_role_check(ctx, *constants.MODERATION_ROLES),
+ in_channel_check(ctx, *constants.MODERATION_CHANNELS)
+ ]
+ return all(checks)
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
diff --git a/bot/constants.py b/bot/constants.py
index f341fb499..60fc1b723 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -515,6 +515,9 @@ STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner
# Roles combinations
STAFF_CHANNELS = Guild.staff_channels
+# Default Channel combinations
+MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam
+
# Bot replies
NEGATIVE_REPLIES = [
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index ad892e512..db56c347c 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -38,9 +38,9 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:
return check
-def in_channel_check(ctx: Context, channel_id: int) -> bool:
- """Checks if the command was executed inside of the specified channel."""
- check = ctx.channel.id == channel_id
+def in_channel_check(ctx: Context, *channel_ids: int) -> bool:
+ """Checks if the command was executed inside the list of specified channels."""
+ check = ctx.channel.id in channel_ids
log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
f"The result of the in_channel check was {check}.")
return check
diff --git a/config-default.yml b/config-default.yml
index 23dcbd44c..8e86234ac 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -109,7 +109,7 @@ guild:
helpers: &HELPERS 385474242440986624
message_log: &MESSAGE_LOG 467752170159079424
meta: 429409067623251969
- mod_spam: &MOD_SPAM 620607373828030464
+ mod_spam: &MOD_SPAM 620607373828030464
mods: &MODS 305126844661760000
mod_alerts: 473092532147060736
modlog: &MODLOG 282638479504965634
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index 22dc93073..19b758336 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -41,3 +41,11 @@ class ChecksTests(unittest.TestCase):
role_id = 42
self.ctx.author.roles.append(MockRole(role_id=role_id))
self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
+
+ def test_in_channel_check_for_correct_channel(self):
+ self.ctx.channel.id = 42
+ self.assertTrue(checks.in_channel_check(self.ctx, *[42]))
+
+ def test_in_channel_check_for_incorrect_channel(self):
+ self.ctx.channel.id = 42 + 10
+ self.assertFalse(checks.in_channel_check(self.ctx, *[42]))