diff options
author | 2019-10-25 09:19:39 -0400 | |
---|---|---|
committer | 2019-10-25 09:19:39 -0400 | |
commit | 586aebb3593b448b65efdfe7a8759e834b6bf407 (patch) | |
tree | 190628d6b35972073a014e8ab4264e1402e7991e | |
parent | Merge pull request #550 from python-discord/###-filtering-devtest (diff) | |
parent | Merge branch 'master' into moderator-channel-check (diff) |
Merge pull request #543 from atmishra/moderator-channel-check
Restrict ModManagement commands to moderation channels
-rw-r--r-- | bot/cogs/moderation/management.py | 10 | ||||
-rw-r--r-- | bot/constants.py | 3 | ||||
-rw-r--r-- | bot/utils/checks.py | 6 | ||||
-rw-r--r-- | config-default.yml | 2 | ||||
-rw-r--r-- | tests/bot/utils/test_checks.py | 8 |
5 files changed, 22 insertions, 7 deletions
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 491f6d400..44a508436 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -11,7 +11,7 @@ from bot import constants from bot.converters import InfractionSearchQuery from bot.pagination import LinePaginator from bot.utils import time -from bot.utils.checks import with_role_check +from bot.utils.checks import in_channel_check, with_role_check from . import utils from .infractions import Infractions from .modlog import ModLog @@ -256,8 +256,12 @@ class ModManagement(commands.Cog): # This cannot be static (must have a __func__ attribute). def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) + """Only allow moderators from moderator channels to invoke the commands in this cog.""" + checks = [ + with_role_check(ctx, *constants.MODERATION_ROLES), + in_channel_check(ctx, *constants.MODERATION_CHANNELS) + ] + return all(checks) # This cannot be static (must have a __func__ attribute). async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/constants.py b/bot/constants.py index f341fb499..60fc1b723 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -515,6 +515,9 @@ STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner # Roles combinations STAFF_CHANNELS = Guild.staff_channels +# Default Channel combinations +MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam + # Bot replies NEGATIVE_REPLIES = [ diff --git a/bot/utils/checks.py b/bot/utils/checks.py index ad892e512..db56c347c 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -38,9 +38,9 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool: return check -def in_channel_check(ctx: Context, channel_id: int) -> bool: - """Checks if the command was executed inside of the specified channel.""" - check = ctx.channel.id == channel_id +def in_channel_check(ctx: Context, *channel_ids: int) -> bool: + """Checks if the command was executed inside the list of specified channels.""" + check = ctx.channel.id in channel_ids log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " f"The result of the in_channel check was {check}.") return check diff --git a/config-default.yml b/config-default.yml index 23dcbd44c..8e86234ac 100644 --- a/config-default.yml +++ b/config-default.yml @@ -109,7 +109,7 @@ guild: helpers: &HELPERS 385474242440986624 message_log: &MESSAGE_LOG 467752170159079424 meta: 429409067623251969 - mod_spam: &MOD_SPAM 620607373828030464 + mod_spam: &MOD_SPAM 620607373828030464 mods: &MODS 305126844661760000 mod_alerts: 473092532147060736 modlog: &MODLOG 282638479504965634 diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 22dc93073..19b758336 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -41,3 +41,11 @@ class ChecksTests(unittest.TestCase): role_id = 42 self.ctx.author.roles.append(MockRole(role_id=role_id)) self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) + + def test_in_channel_check_for_correct_channel(self): + self.ctx.channel.id = 42 + self.assertTrue(checks.in_channel_check(self.ctx, *[42])) + + def test_in_channel_check_for_incorrect_channel(self): + self.ctx.channel.id = 42 + 10 + self.assertFalse(checks.in_channel_check(self.ctx, *[42])) |