diff options
Diffstat (limited to '')
| -rw-r--r-- | bot/cogs/moderation/management.py | 10 | ||||
| -rw-r--r-- | bot/constants.py | 3 | ||||
| -rw-r--r-- | bot/utils/checks.py | 6 | ||||
| -rw-r--r-- | config-default.yml | 2 | ||||
| -rw-r--r-- | tests/bot/utils/test_checks.py | 8 | 
5 files changed, 22 insertions, 7 deletions
| diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 491f6d400..44a508436 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -11,7 +11,7 @@ from bot import constants  from bot.converters import InfractionSearchQuery  from bot.pagination import LinePaginator  from bot.utils import time -from bot.utils.checks import with_role_check +from bot.utils.checks import in_channel_check, with_role_check  from . import utils  from .infractions import Infractions  from .modlog import ModLog @@ -256,8 +256,12 @@ class ModManagement(commands.Cog):      # This cannot be static (must have a __func__ attribute).      def cog_check(self, ctx: Context) -> bool: -        """Only allow moderators to invoke the commands in this cog.""" -        return with_role_check(ctx, *constants.MODERATION_ROLES) +        """Only allow moderators from moderator channels to invoke the commands in this cog.""" +        checks = [ +            with_role_check(ctx, *constants.MODERATION_ROLES), +            in_channel_check(ctx, *constants.MODERATION_CHANNELS) +        ] +        return all(checks)      # This cannot be static (must have a __func__ attribute).      async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/constants.py b/bot/constants.py index f341fb499..60fc1b723 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -515,6 +515,9 @@ STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner  # Roles combinations  STAFF_CHANNELS = Guild.staff_channels +# Default Channel combinations +MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam +  # Bot replies  NEGATIVE_REPLIES = [ diff --git a/bot/utils/checks.py b/bot/utils/checks.py index ad892e512..db56c347c 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -38,9 +38,9 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:      return check -def in_channel_check(ctx: Context, channel_id: int) -> bool: -    """Checks if the command was executed inside of the specified channel.""" -    check = ctx.channel.id == channel_id +def in_channel_check(ctx: Context, *channel_ids: int) -> bool: +    """Checks if the command was executed inside the list of specified channels.""" +    check = ctx.channel.id in channel_ids      log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "                f"The result of the in_channel check was {check}.")      return check diff --git a/config-default.yml b/config-default.yml index 23dcbd44c..8e86234ac 100644 --- a/config-default.yml +++ b/config-default.yml @@ -109,7 +109,7 @@ guild:          helpers:           &HELPERS       385474242440986624          message_log:       &MESSAGE_LOG   467752170159079424          meta:                             429409067623251969 -        mod_spam:          &MOD_SPAM      620607373828030464     +        mod_spam:          &MOD_SPAM      620607373828030464          mods:              &MODS          305126844661760000          mod_alerts:                       473092532147060736          modlog:            &MODLOG        282638479504965634 diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 22dc93073..19b758336 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -41,3 +41,11 @@ class ChecksTests(unittest.TestCase):          role_id = 42          self.ctx.author.roles.append(MockRole(role_id=role_id))          self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) + +    def test_in_channel_check_for_correct_channel(self): +        self.ctx.channel.id = 42 +        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) + +    def test_in_channel_check_for_incorrect_channel(self): +        self.ctx.channel.id = 42 + 10 +        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) | 
