diff options
| -rw-r--r-- | bot/constants.py | 10 | ||||
| -rw-r--r-- | bot/exts/moderation/silence.py | 287 | ||||
| -rw-r--r-- | config-default.yml | 14 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 405 | ||||
| -rw-r--r-- | tests/helpers.py | 25 | 
5 files changed, 638 insertions, 103 deletions
| diff --git a/bot/constants.py b/bot/constants.py index cc3aa41a5..916ae77e6 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -445,15 +445,17 @@ class Channels(metaclass=YAMLGetter):      staff_announcements: int      admins_voice: int +    code_help_voice_0: int      code_help_voice_1: int -    code_help_voice_2: int -    general_voice: int +    general_voice_0: int +    general_voice_1: int      staff_voice: int +    code_help_chat_0: int      code_help_chat_1: int -    code_help_chat_2: int      staff_voice_chat: int -    voice_chat: int +    voice_chat_0: int +    voice_chat_1: int      big_brother_logs: int      talent_pool: int diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 2a7ca932e..616dfbefb 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -2,35 +2,44 @@ import json  import logging  from contextlib import suppress  from datetime import datetime, timedelta, timezone -from operator import attrgetter -from typing import Optional +from typing import Optional, OrderedDict, Union  from async_rediscache import RedisCache -from discord import TextChannel +from discord import Guild, PermissionOverwrite, TextChannel, VoiceChannel  from discord.ext import commands, tasks  from discord.ext.commands import Context +from bot import constants  from bot.bot import Bot -from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles  from bot.converters import HushDurationConverter -from bot.utils.lock import LockedResourceError, lock_arg +from bot.utils.lock import LockedResourceError, lock, lock_arg  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__)  LOCK_NAMESPACE = "silence" -MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced." -MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely." -MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)." +MSG_SILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel is already silenced." +MSG_SILENCE_PERMANENT = f"{constants.Emojis.check_mark} silenced current channel indefinitely." +MSG_SILENCE_SUCCESS = f"{constants.Emojis.check_mark} silenced current channel for {{duration}} minute(s)." -MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced." +MSG_UNSILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel was not silenced."  MSG_UNSILENCE_MANUAL = ( -    f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " +    f"{constants.Emojis.cross_mark} current channel was not unsilenced because the current overwrites were "      f"set manually or the cache was prematurely cleared. "      f"Please edit the overwrites manually to unsilence."  ) -MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel." +MSG_UNSILENCE_SUCCESS = f"{constants.Emojis.check_mark} unsilenced current channel." + +TextOrVoiceChannel = Union[TextChannel, VoiceChannel] + +VOICE_CHANNELS = { +    constants.Channels.code_help_voice_0: constants.Channels.code_help_chat_0, +    constants.Channels.code_help_voice_1: constants.Channels.code_help_chat_1, +    constants.Channels.general_voice_0: constants.Channels.voice_chat_0, +    constants.Channels.general_voice_1: constants.Channels.voice_chat_1, +    constants.Channels.staff_voice: constants.Channels.staff_voice_chat, +}  class SilenceNotifier(tasks.Loop): @@ -41,7 +50,7 @@ class SilenceNotifier(tasks.Loop):          self._silenced_channels = {}          self._alert_channel = alert_channel -    def add_channel(self, channel: TextChannel) -> None: +    def add_channel(self, channel: TextOrVoiceChannel) -> None:          """Add channel to `_silenced_channels` and start loop if not launched."""          if not self._silenced_channels:              self.start() @@ -68,7 +77,19 @@ class SilenceNotifier(tasks.Loop):                  f"{channel.mention} for {(self._current_loop-start)//60} min"                  for channel, start in self._silenced_channels.items()              ) -            await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") +            await self._alert_channel.send( +                f"<@&{constants.Roles.moderators}> currently silenced channels: {channels_text}" +            ) + + +async def _select_lock_channel(args: OrderedDict[str, any]) -> TextOrVoiceChannel: +    """Passes the channel to be silenced to the resource lock.""" +    channel = args["channel"] +    if channel is not None: +        return channel + +    else: +        return args["ctx"].channel  class Silence(commands.Cog): @@ -92,88 +113,230 @@ class Silence(commands.Cog):          """Set instance attributes once the guild is available and reschedule unsilences."""          await self.bot.wait_until_guild_available() -        guild = self.bot.get_guild(Guild.id) +        guild = self.bot.get_guild(constants.Guild.id) +          self._everyone_role = guild.default_role -        self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) -        self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log)) +        self._verified_voice_role = guild.get_role(constants.Roles.voice_verified) + +        self._mod_alerts_channel = self.bot.get_channel(constants.Channels.mod_alerts) + +        self.notifier = SilenceNotifier(self.bot.get_channel(constants.Channels.mod_log))          await self._reschedule() +    async def send_message( +        self, +        message: str, +        source_channel: TextChannel, +        target_channel: TextOrVoiceChannel, +        *, alert_target: bool = False +    ) -> None: +        """Helper function to send message confirmation to `source_channel`, and notification to `target_channel`.""" +        # Reply to invocation channel +        source_reply = message +        if source_channel != target_channel: +            source_reply = source_reply.replace("current channel", target_channel.mention) +        await source_channel.send(source_reply) + +        # Reply to target channel +        if alert_target: +            if isinstance(target_channel, VoiceChannel): +                voice_chat = self.bot.get_channel(VOICE_CHANNELS.get(target_channel.id)) +                if voice_chat and source_channel != voice_chat: +                    await voice_chat.send(message.replace("current channel", target_channel.mention)) + +            elif source_channel != target_channel: +                await target_channel.send(message) +      @commands.command(aliases=("hush",)) -    @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True) -    async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: +    @lock(LOCK_NAMESPACE, _select_lock_channel, raise_error=True) +    async def silence( +        self, +        ctx: Context, +        duration: HushDurationConverter = 10, +        channel: TextOrVoiceChannel = None, +        *, kick: bool = False +    ) -> None:          """          Silence the current channel for `duration` minutes or `forever`.          Duration is capped at 15 minutes, passing forever makes the silence indefinite.          Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. + +        Passing a voice channel will attempt to move members out of the channel and back to force sync permissions. +        If `kick` is True, members will not be added back to the voice channel, and members will be unable to rejoin.          """          await self._init_task - -        channel_info = f"#{ctx.channel} ({ctx.channel.id})" +        if channel is None: +            channel = ctx.channel +        channel_info = f"#{channel} ({channel.id})"          log.debug(f"{ctx.author} is silencing channel {channel_info}.") -        if not await self._set_silence_overwrites(ctx.channel): +        if not await self._set_silence_overwrites(channel, kick=kick):              log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.") -            await ctx.send(MSG_SILENCE_FAIL) +            await self.send_message(MSG_SILENCE_FAIL, ctx.channel, channel, alert_target=False)              return -        await self._schedule_unsilence(ctx, duration) +        if isinstance(channel, VoiceChannel): +            if kick: +                await self._kick_voice_members(channel) +            else: +                await self._force_voice_sync(channel) + +        await self._schedule_unsilence(ctx, channel, duration)          if duration is None: -            self.notifier.add_channel(ctx.channel) +            self.notifier.add_channel(channel)              log.info(f"Silenced {channel_info} indefinitely.") -            await ctx.send(MSG_SILENCE_PERMANENT) +            await self.send_message(MSG_SILENCE_PERMANENT, ctx.channel, channel, alert_target=True) +          else:              log.info(f"Silenced {channel_info} for {duration} minute(s).") -            await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration)) +            formatted_message = MSG_SILENCE_SUCCESS.format(duration=duration) +            await self.send_message(formatted_message, ctx.channel, channel, alert_target=True)      @commands.command(aliases=("unhush",)) -    async def unsilence(self, ctx: Context) -> None: +    async def unsilence(self, ctx: Context, *, channel: TextOrVoiceChannel = None) -> None:          """ -        Unsilence the current channel. +        Unsilence the given channel if given, else the current one.          If the channel was silenced indefinitely, notifications for the channel will stop.          """          await self._init_task -        log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") -        await self._unsilence_wrapper(ctx.channel) +        if channel is None: +            channel = ctx.channel +        log.debug(f"Unsilencing channel #{channel} from {ctx.author}'s command.") +        await self._unsilence_wrapper(channel, ctx)      @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True) -    async def _unsilence_wrapper(self, channel: TextChannel) -> None: -        """Unsilence `channel` and send a success/failure message.""" +    async def _unsilence_wrapper(self, channel: TextOrVoiceChannel, ctx: Optional[Context] = None) -> None: +        """ +        Unsilence `channel` and send a success/failure message to ctx.channel. + +        If ctx is None or not passed, `channel` is used in its place. +        If `channel` and ctx.channel are the same, only one message is sent. +        """ +        msg_channel = channel +        if ctx is not None: +            msg_channel = ctx.channel +          if not await self._unsilence(channel): -            overwrite = channel.overwrites_for(self._everyone_role) -            if overwrite.send_messages is False or overwrite.add_reactions is False: -                await channel.send(MSG_UNSILENCE_MANUAL) +            if isinstance(channel, VoiceChannel): +                overwrite = channel.overwrites_for(self._verified_voice_role) +                manual = overwrite.speak is False              else: -                await channel.send(MSG_UNSILENCE_FAIL) +                overwrite = channel.overwrites_for(self._everyone_role) +                manual = overwrite.send_messages is False or overwrite.add_reactions is False + +            # Send fail message to muted channel or voice chat channel, and invocation channel +            if manual: +                await self.send_message(MSG_UNSILENCE_MANUAL, msg_channel, channel, alert_target=False) +            else: +                await self.send_message(MSG_UNSILENCE_FAIL, msg_channel, channel, alert_target=False) +          else: -            await channel.send(MSG_UNSILENCE_SUCCESS) +            await self.send_message(MSG_UNSILENCE_SUCCESS, msg_channel, channel, alert_target=True) -    async def _set_silence_overwrites(self, channel: TextChannel) -> bool: +    async def _set_silence_overwrites(self, channel: TextOrVoiceChannel, *, kick: bool = False) -> bool:          """Set silence permission overwrites for `channel` and return True if successful.""" -        overwrite = channel.overwrites_for(self._everyone_role) -        prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) +        # Get the original channel overwrites +        if isinstance(channel, TextChannel): +            role = self._everyone_role +            overwrite = channel.overwrites_for(role) +            prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) +        else: +            role = self._verified_voice_role +            overwrite = channel.overwrites_for(role) +            prev_overwrites = dict(speak=overwrite.speak) +            if kick: +                prev_overwrites.update(connect=overwrite.connect) + +        # Stop if channel was already silenced          if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()):              return False -        overwrite.update(send_messages=False, add_reactions=False) -        await channel.set_permissions(self._everyone_role, overwrite=overwrite) +        # Set new permissions, store +        overwrite.update(**dict.fromkeys(prev_overwrites, False)) +        await channel.set_permissions(role, overwrite=overwrite)          await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites))          return True -    async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None: +    @staticmethod +    async def _get_afk_channel(guild: Guild) -> VoiceChannel: +        """Get a guild's AFK channel, or create one if it does not exist.""" +        afk_channel = guild.afk_channel + +        if afk_channel is None: +            overwrites = { +                guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) +            } +            afk_channel = await guild.create_voice_channel("mute-temp", overwrites=overwrites) +            log.info(f"Failed to get afk-channel, created temporary channel #{afk_channel} ({afk_channel.id})") + +        return afk_channel + +    @staticmethod +    async def _kick_voice_members(channel: VoiceChannel) -> None: +        """Remove all non-staff members from a voice channel.""" +        log.debug(f"Removing all non staff members from #{channel.name} ({channel.id}).") + +        for member in channel.members: +            # Skip staff +            if any(role.id in constants.MODERATION_ROLES for role in member.roles): +                continue + +            try: +                await member.move_to(None, reason="Kicking member from voice channel.") +                log.debug(f"Kicked {member.name} from voice channel.") +            except Exception as e: +                log.debug(f"Failed to move {member.name}. Reason: {e}") +                continue + +        log.debug("Removed all members.") + +    async def _force_voice_sync(self, channel: VoiceChannel) -> None: +        """ +        Move all non-staff members from `channel` to a temporary channel and back to force toggle role mute. + +        Permission modification has to happen before this function. +        """ +        # Obtain temporary channel +        delete_channel = channel.guild.afk_channel is None +        afk_channel = await self._get_afk_channel(channel.guild) + +        try: +            # Move all members to temporary channel and back +            for member in channel.members: +                # Skip staff +                if any(role.id in constants.MODERATION_ROLES for role in member.roles): +                    continue + +                try: +                    await member.move_to(afk_channel, reason="Muting VC member.") +                    log.debug(f"Moved {member.name} to afk channel.") + +                    await member.move_to(channel, reason="Muting VC member.") +                    log.debug(f"Moved {member.name} to original voice channel.") +                except Exception as e: +                    log.debug(f"Failed to move {member.name}. Reason: {e}") +                    continue + +        finally: +            # Delete VC channel if it was created. +            if delete_channel: +                await afk_channel.delete(reason="Deleting temporary mute channel.") + +    async def _schedule_unsilence(self, ctx: Context, channel: TextOrVoiceChannel, duration: Optional[int]) -> None:          """Schedule `ctx.channel` to be unsilenced if `duration` is not None."""          if duration is None: -            await self.unsilence_timestamps.set(ctx.channel.id, -1) +            await self.unsilence_timestamps.set(channel.id, -1)          else: -            self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) +            self.scheduler.schedule_later(duration * 60, channel.id, ctx.invoke(self.unsilence, channel=channel))              unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration) -            await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp()) +            await self.unsilence_timestamps.set(channel.id, unsilence_time.timestamp()) -    async def _unsilence(self, channel: TextChannel) -> bool: +    async def _unsilence(self, channel: TextOrVoiceChannel) -> bool:          """          Unsilence `channel`. @@ -183,19 +346,34 @@ class Silence(commands.Cog):          Return `True` if channel permissions were changed, `False` otherwise.          """ +        # Get stored overwrites, and return if channel is unsilenced          prev_overwrites = await self.previous_overwrites.get(channel.id)          if channel.id not in self.scheduler and prev_overwrites is None:              log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.")              return False -        overwrite = channel.overwrites_for(self._everyone_role) +        # Select the role based on channel type, and get current overwrites +        if isinstance(channel, TextChannel): +            role = self._everyone_role +            overwrite = channel.overwrites_for(role) +            permissions = "`Send Messages` and `Add Reactions`" +        else: +            role = self._verified_voice_role +            overwrite = channel.overwrites_for(role) +            permissions = "`Speak` and `Connect`" + +        # Check if old overwrites were not stored          if prev_overwrites is None:              log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.") -            overwrite.update(send_messages=None, add_reactions=None) +            overwrite.update(send_messages=None, add_reactions=None, speak=None, connect=None)          else:              overwrite.update(**json.loads(prev_overwrites)) -        await channel.set_permissions(self._everyone_role, overwrite=overwrite) +        # Update Permissions +        await channel.set_permissions(role, overwrite=overwrite) +        if isinstance(channel, VoiceChannel): +            await self._force_voice_sync(channel) +          log.info(f"Unsilenced channel #{channel} ({channel.id}).")          self.scheduler.cancel(channel.id) @@ -203,11 +381,12 @@ class Silence(commands.Cog):          await self.previous_overwrites.delete(channel.id)          await self.unsilence_timestamps.delete(channel.id) +        # Alert Admin team if old overwrites were not available          if prev_overwrites is None:              await self._mod_alerts_channel.send( -                f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing " -                f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` " -                f"overwrites for {self._everyone_role.mention} are at their desired values." +                f"<@&{constants.Roles.admins}> Restored overwrites with default values after unsilencing " +                f"{channel.mention}. Please check that the {permissions} " +                f"overwrites for {role.mention} are at their desired values."              )          return True @@ -247,7 +426,7 @@ class Silence(commands.Cog):      # This cannot be static (must have a __func__ attribute).      async def cog_check(self, ctx: Context) -> bool:          """Only allow moderators to invoke the commands in this cog.""" -        return await commands.has_any_role(*MODERATION_ROLES).predicate(ctx) +        return await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx)  def setup(bot: Bot) -> None: diff --git a/config-default.yml b/config-default.yml index b7c446889..9f4c9b80b 100644 --- a/config-default.yml +++ b/config-default.yml @@ -207,16 +207,18 @@ guild:          # Voice Channels          admins_voice:       &ADMINS_VOICE   500734494840717332 -        code_help_voice_1:                  751592231726481530 -        code_help_voice_2:                  764232549840846858 -        general_voice:                      751591688538947646 +        code_help_voice_0:                  751592231726481530 +        code_help_voice_1:                  764232549840846858 +        general_voice_0:                    751591688538947646 +        general_voice_1:                    799641437645701151          staff_voice:        &STAFF_VOICE    412375055910043655          # Voice Chat -        code_help_chat_1:                   755154969761677312 -        code_help_chat_2:                   766330079135268884 +        code_help_chat_0:                   755154969761677312 +        code_help_chat_1:                   766330079135268884          staff_voice_chat:                   541638762007101470 -        voice_chat:                         412357430186344448 +        voice_chat_0:                       412357430186344448 +        voice_chat_1:                       799647045886541885          # Watch          big_brother_logs:   &BB_LOGS        468507907357409333 diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index fa5fc9e81..d7542c562 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,15 +1,25 @@  import asyncio  import unittest  from datetime import datetime, timezone +from typing import List, Tuple  from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock  from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Guild, Roles +from bot.constants import Channels, Guild, MODERATION_ROLES, Roles  from bot.exts.moderation import silence -from tests.helpers import MockBot, MockContext, MockTextChannel, autospec +from tests.helpers import ( +    MockBot, +    MockContext, +    MockGuild, +    MockMember, +    MockRole, +    MockTextChannel, +    MockVoiceChannel, +    autospec +)  redis_session = None  redis_loop = asyncio.get_event_loop() @@ -149,7 +159,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):          self.assertTrue(self.cog._init_task.cancelled())      @autospec("discord.ext.commands", "has_any_role") -    @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) +    @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3))      async def test_cog_check(self, role_check):          """Role check was called with `MODERATION_ROLES`"""          ctx = MockContext() @@ -159,6 +169,97 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):          role_check.assert_called_once_with(*(1, 2, 3))          role_check.return_value.predicate.assert_awaited_once_with(ctx) +    async def test_force_voice_sync(self): +        """Tests the _force_voice_sync helper function.""" +        await self.cog._async_init() + +        members = [MockMember() for _ in range(10)] +        members.extend([MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES]) + +        channel = MockVoiceChannel(members=members) + +        await self.cog._force_voice_sync(channel) +        for member in members: +            if any(role.id in MODERATION_ROLES for role in member.roles): +                member.move_to.assert_not_called() +            else: +                self.assertEqual(member.move_to.call_count, 2) +                calls = member.move_to.call_args_list + +                # Tests that the member was moved to the afk channel, and back. +                self.assertEqual((channel.guild.afk_channel,), calls[0].args) +                self.assertEqual((channel,), calls[1].args) + +    async def test_force_voice_sync_no_channel(self): +        """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" +        await self.cog._async_init() + +        channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) +        new_channel = MockVoiceChannel(delete=AsyncMock()) +        channel.guild.create_voice_channel.return_value = new_channel + +        await self.cog._force_voice_sync(channel) + +        # Check channel creation +        overwrites = { +            channel.guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) +        } +        channel.guild.create_voice_channel.assert_called_once_with("mute-temp", overwrites=overwrites) + +        # Check bot deleted channel +        new_channel.delete.assert_called_once() + +    async def test_voice_kick(self): +        """Test to ensure kick function can remove all members from a voice channel.""" +        await self.cog._async_init() + +        members = [MockMember() for _ in range(10)] +        members.extend([MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES]) + +        channel = MockVoiceChannel(members=members) +        await self.cog._kick_voice_members(channel) + +        for member in members: +            if any(role.id in MODERATION_ROLES for role in member.roles): +                member.move_to.assert_not_called() +            else: +                self.assertEqual((None,), member.move_to.call_args_list[0].args) + +    @staticmethod +    def create_erroneous_members() -> Tuple[List[MockMember], List[MockMember]]: +        """ +        Helper method to generate a list of members that error out on move_to call. + +        Returns the list of erroneous members, +        as well as a list of regular and erroneous members combined, in that order. +        """ +        erroneous_members = [MockMember(move_to=Mock(side_effect=Exception())) for _ in range(5)] + +        members = [] +        for i in range(5): +            members.append(MockMember()) +            members.append(erroneous_members[i]) + +        return erroneous_members, members + +    async def test_kick_move_to_error(self): +        """Test to ensure move_to gets called on all members during kick, even if some fail.""" +        await self.cog._async_init() +        failing_members, members = self.create_erroneous_members() + +        await self.cog._kick_voice_members(MockVoiceChannel(members=members)) +        for member in members: +            member.move_to.assert_called_once() + +    async def test_sync_move_to_error(self): +        """Test to ensure move_to gets called on all members during sync, even if some fail.""" +        await self.cog._async_init() +        failing_members, members = self.create_erroneous_members() + +        await self.cog._force_voice_sync(MockVoiceChannel(members=members)) +        for member in members: +            self.assertEqual(member.move_to.call_count, 1 if member in failing_members else 2) +  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class RescheduleTests(unittest.IsolatedAsyncioTestCase): @@ -235,6 +336,16 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase):          self.cog.notifier.add_channel.assert_not_called() +def voice_sync_helper(function): +    """Helper wrapper to test the sync and kick functions for voice channels.""" +    @autospec(silence.Silence, "_force_voice_sync", "_kick_voice_members", "_set_silence_overwrites") +    async def inner(self, sync, kick, overwrites): +        overwrites.return_value = True +        await function(self, MockContext(), sync, kick) + +    return inner + +  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class SilenceTests(unittest.IsolatedAsyncioTestCase):      """Tests for the silence command and its related helper methods.""" @@ -242,7 +353,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      @autospec(silence.Silence, "_reschedule", pass_mocks=False)      @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot() +        self.bot = MockBot(get_channel=lambda _: MockTextChannel())          self.cog = silence.Silence(self.bot)          self.cog._init_task = asyncio.Future()          self.cog._init_task.set_result(None) @@ -252,23 +363,73 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          asyncio.run(self.cog._async_init())  # Populate instance attributes. -        self.channel = MockTextChannel() -        self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) -        self.channel.overwrites_for.return_value = self.overwrite +        self.text_channel = MockTextChannel() +        self.text_overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) +        self.text_channel.overwrites_for.return_value = self.text_overwrite + +        self.voice_channel = MockVoiceChannel() +        self.voice_overwrite = PermissionOverwrite(speak=True) +        self.voice_channel.overwrites_for.return_value = self.voice_overwrite      async def test_sent_correct_message(self): -        """Appropriate failure/success message was sent by the command.""" +        """Appropriate failure/success message was sent by the command to the correct channel.""" +        # The following test tuples are made up of: +        # duration, expected message, and the success of the _set_silence_overwrites function          test_cases = (              (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,),              (None, silence.MSG_SILENCE_PERMANENT, True,),              (5, silence.MSG_SILENCE_FAIL, False,),          ) +          for duration, message, was_silenced in test_cases: -            ctx = MockContext()              with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): -                with self.subTest(was_silenced=was_silenced, message=message, duration=duration): -                    await self.cog.silence.callback(self.cog, ctx, duration) -                    ctx.send.assert_called_once_with(message) +                for target in [MockTextChannel(), MockVoiceChannel(), None]: +                    with self.subTest(was_silenced=was_silenced, target=target, message=message): +                        with mock.patch.object(self.cog, "send_message") as send_message: +                            ctx = MockContext() +                            await self.cog.silence.callback(self.cog, ctx, duration, target) +                            send_message.assert_called_once_with( +                                message, +                                ctx.channel, +                                target or ctx.channel, +                                alert_target=was_silenced +                            ) + +    @voice_sync_helper +    async def test_sync_called(self, ctx, sync, kick): +        """Tests if silence command calls sync on a voice channel.""" +        channel = MockVoiceChannel() +        await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False) + +        sync.assert_called_once_with(self.cog, channel) +        kick.assert_not_called() + +    @voice_sync_helper +    async def test_kick_called(self, ctx, sync, kick): +        """Tests if silence command calls kick on a voice channel.""" +        channel = MockVoiceChannel() +        await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True) + +        kick.assert_called_once_with(channel) +        sync.assert_not_called() + +    @voice_sync_helper +    async def test_sync_not_called(self, ctx, sync, kick): +        """Tests that silence command does not call sync on a text channel.""" +        channel = MockTextChannel() +        await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False) + +        sync.assert_not_called() +        kick.assert_not_called() + +    @voice_sync_helper +    async def test_kick_not_called(self, ctx, sync, kick): +        """Tests that silence command does not call kick on a text channel.""" +        channel = MockTextChannel() +        await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True) + +        sync.assert_not_called() +        kick.assert_not_called()      async def test_skipped_already_silenced(self):          """Permissions were not set and `False` was returned for an already silenced channel.""" @@ -287,21 +448,39 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):                  self.assertFalse(await self.cog._set_silence_overwrites(channel))                  channel.set_permissions.assert_not_called() -    async def test_silenced_channel(self): +    async def test_silenced_text_channel(self):          """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" -        self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) -        self.assertFalse(self.overwrite.send_messages) -        self.assertFalse(self.overwrite.add_reactions) -        self.channel.set_permissions.assert_awaited_once_with( +        self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel)) +        self.assertFalse(self.text_overwrite.send_messages) +        self.assertFalse(self.text_overwrite.add_reactions) +        self.text_channel.set_permissions.assert_awaited_once_with(              self.cog._everyone_role, -            overwrite=self.overwrite +            overwrite=self.text_overwrite +        ) + +    async def test_silenced_voice_channel_speak(self): +        """Channel had `speak` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel)) +        self.assertFalse(self.voice_overwrite.speak) +        self.voice_channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_voice_role, +            overwrite=self.voice_overwrite +        ) + +    async def test_silenced_voice_channel_full(self): +        """Channel had `speak` and `connect` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel, kick=True)) +        self.assertFalse(self.voice_overwrite.speak or self.voice_overwrite.connect) +        self.voice_channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_voice_role, +            overwrite=self.voice_overwrite          )      async def test_preserved_other_overwrites(self):          """Channel's other unrelated overwrites were not changed.""" -        prev_overwrite_dict = dict(self.overwrite) -        await self.cog._set_silence_overwrites(self.channel) -        new_overwrite_dict = dict(self.overwrite) +        prev_overwrite_dict = dict(self.text_overwrite) +        await self.cog._set_silence_overwrites(self.text_channel) +        new_overwrite_dict = dict(self.text_overwrite)          # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.          del prev_overwrite_dict['send_messages'] @@ -332,8 +511,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      async def test_cached_previous_overwrites(self):          """Channel's previous overwrites were cached."""          overwrite_json = '{"send_messages": true, "add_reactions": false}' -        await self.cog._set_silence_overwrites(self.channel) -        self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) +        await self.cog._set_silence_overwrites(self.text_channel) +        self.cog.previous_overwrites.set.assert_called_once_with(self.text_channel.id, overwrite_json)      @autospec(silence, "datetime")      async def test_cached_unsilence_time(self, datetime_mock): @@ -343,7 +522,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          timestamp = now_timestamp + duration * 60          datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) -        ctx = MockContext(channel=self.channel) +        ctx = MockContext(channel=self.text_channel)          await self.cog.silence.callback(self.cog, ctx, duration)          self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) @@ -351,23 +530,23 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      async def test_cached_indefinite_time(self):          """A value of -1 was cached for a permanent silence.""" -        ctx = MockContext(channel=self.channel) +        ctx = MockContext(channel=self.text_channel)          await self.cog.silence.callback(self.cog, ctx, None)          self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)      async def test_scheduled_task(self):          """An unsilence task was scheduled.""" -        ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) +        ctx = MockContext(channel=self.text_channel, invoke=mock.MagicMock())          await self.cog.silence.callback(self.cog, ctx, 5)          args = (300, ctx.channel.id, ctx.invoke.return_value)          self.cog.scheduler.schedule_later.assert_called_once_with(*args) -        ctx.invoke.assert_called_once_with(self.cog.unsilence) +        ctx.invoke.assert_called_once_with(self.cog.unsilence, channel=ctx.channel)      async def test_permanent_not_scheduled(self):          """A task was not scheduled for a permanent silence.""" -        ctx = MockContext(channel=self.channel) +        ctx = MockContext(channel=self.text_channel)          await self.cog.silence.callback(self.cog, ctx, None)          self.cog.scheduler.schedule_later.assert_not_called() @@ -405,13 +584,22 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):              (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)),              (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)),          ) +          for was_unsilenced, message, overwrite in test_cases:              ctx = MockContext() -            with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): + +            for target in [None, MockTextChannel()]: +                ctx.channel.overwrites_for.return_value = overwrite +                if target: +                    target.overwrites_for.return_value = overwrite +                  with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): -                    ctx.channel.overwrites_for.return_value = overwrite -                    await self.cog.unsilence.callback(self.cog, ctx) -                    ctx.channel.send.assert_called_once_with(message) +                    with mock.patch.object(self.cog, "send_message") as send_message: +                        with self.subTest(was_unsilenced=was_unsilenced, overwrite=overwrite, target=target): +                            await self.cog.unsilence.callback(self.cog, ctx, channel=target) + +                            call_args = (message, ctx.channel, target or ctx.channel) +                            send_message.assert_called_once_with(*call_args, alert_target=was_unsilenced)      async def test_skipped_already_unsilenced(self):          """Permissions were not set and `False` was returned for an already unsilenced channel.""" @@ -447,13 +635,18 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):          self.assertIsNone(self.overwrite.send_messages)          self.assertIsNone(self.overwrite.add_reactions) -    async def test_cache_miss_sent_mod_alert(self): -        """A message was sent to the mod alerts channel.""" +    async def test_cache_miss_sent_mod_alert_text(self): +        """A message was sent to the mod alerts channel upon muting a text channel."""          self.cog.previous_overwrites.get.return_value = None -          await self.cog._unsilence(self.channel)          self.cog._mod_alerts_channel.send.assert_awaited_once() +    async def test_cache_miss_sent_mod_alert_voice(self): +        """A message was sent to the mod alerts channel upon muting a voice channel.""" +        self.cog.previous_overwrites.get.return_value = None +        await self.cog._unsilence(MockVoiceChannel()) +        self.cog._mod_alerts_channel.send.assert_awaited_once() +      async def test_removed_notifier(self):          """Channel was removed from `notifier`."""          await self.cog._unsilence(self.channel) @@ -491,3 +684,143 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):                  del new_overwrite_dict['add_reactions']                  self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_unsilence_role(self): +        """Tests unsilence_wrapper applies permission to the correct role.""" +        test_cases = ( +            (MockTextChannel(), self.cog.bot.get_guild(Guild.id).default_role), +            (MockVoiceChannel(), self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified)) +        ) + +        for channel, role in test_cases: +            with self.subTest(channel=channel, role=role): +                await self.cog._unsilence_wrapper(channel, MockContext()) +                channel.overwrites_for.assert_called_with(role) + +    @mock.patch.object(silence.Silence, "_force_voice_sync") +    @mock.patch.object(silence.Silence, "send_message") +    async def test_correct_overwrites(self, send_message, _): +        """Tests the overwrites returned by the _unsilence_wrapper are correct for voice and text channels.""" +        ctx = MockContext() + +        text_channel = MockTextChannel() +        text_role = self.cog.bot.get_guild(Guild.id).default_role + +        voice_channel = MockVoiceChannel() +        voice_role = self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified) + +        async def reset(): +            await text_channel.set_permissions(text_role, PermissionOverwrite(send_messages=False, add_reactions=False)) +            await voice_channel.set_permissions(voice_role, PermissionOverwrite(speak=False, connect=False)) + +            text_channel.reset_mock() +            voice_channel.reset_mock() +            send_message.reset_mock() +        await reset() + +        default_text_overwrites = text_channel.overwrites_for(text_role) +        default_voice_overwrites = voice_channel.overwrites_for(voice_role) + +        test_cases = ( +            (ctx, text_channel, text_role, default_text_overwrites, silence.MSG_UNSILENCE_SUCCESS), +            (ctx, voice_channel, voice_role, default_voice_overwrites, silence.MSG_UNSILENCE_SUCCESS), +            (ctx, ctx.channel, text_role, ctx.channel.overwrites_for(text_role), silence.MSG_UNSILENCE_SUCCESS), +            (None, text_channel, text_role, default_text_overwrites, silence.MSG_UNSILENCE_SUCCESS), +        ) + +        for context, channel, role, overwrites, message in test_cases: +            with self.subTest(ctx=context, channel=channel): +                await self.cog._unsilence_wrapper(channel, context) + +                if context is None: +                    send_message.assert_called_once_with(message, channel, channel, alert_target=True) +                else: +                    send_message.assert_called_once_with(message, context.channel, channel, alert_target=True) + +                channel.set_permissions.assert_called_once_with(role, overwrite=overwrites) +                if channel != ctx.channel: +                    ctx.channel.send.assert_not_called() + +            await reset() + + +class SendMessageTests(unittest.IsolatedAsyncioTestCase): +    """Unittests for the send message helper function.""" + +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) + +        self.text_channels = [MockTextChannel() for _ in range(2)] +        self.bot.get_channel.return_value = self.text_channels[1] + +        self.voice_channel = MockVoiceChannel() + +    async def test_send_to_channel(self): +        """Tests a basic case for the send function.""" +        message = "Test basic message." +        await self.cog.send_message(message, *self.text_channels, alert_target=False) + +        self.text_channels[0].send.assert_called_once_with(message) +        self.text_channels[1].send.assert_not_called() + +    async def test_send_to_multiple_channels(self): +        """Tests sending messages to two channels.""" +        message = "Test basic message." +        await self.cog.send_message(message, *self.text_channels, alert_target=True) + +        self.text_channels[0].send.assert_called_once_with(message) +        self.text_channels[1].send.assert_called_once_with(message) + +    async def test_duration_replacement(self): +        """Tests that the channel name was set correctly for one target channel.""" +        message = "Current. The following should be replaced: current channel." +        await self.cog.send_message(message, *self.text_channels, alert_target=False) + +        updated_message = message.replace("current channel", self.text_channels[0].mention) +        self.text_channels[0].send.assert_called_once_with(updated_message) +        self.text_channels[1].send.assert_not_called() + +    async def test_name_replacement_multiple_channels(self): +        """Tests that the channel name was set correctly for two channels.""" +        message = "Current. The following should be replaced: current channel." +        await self.cog.send_message(message, *self.text_channels, alert_target=True) + +        updated_message = message.replace("current channel", self.text_channels[0].mention) +        self.text_channels[0].send.assert_called_once_with(updated_message) +        self.text_channels[1].send.assert_called_once_with(message) + +    async def test_silence_voice(self): +        """Tests that the correct message was sent when a voice channel is muted without alerting.""" +        message = "This should show up just here." +        await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=False) +        self.text_channels[0].send.assert_called_once_with(message) +        self.text_channels[1].send.assert_not_called() + +    async def test_silence_voice_alert(self): +        """Tests that the correct message was sent when a voice channel is muted with alerts.""" +        with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: +            mock_voice_channels.get.return_value = self.text_channels[1].id + +            message = "This should show up as current channel." +            await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=True) + +        updated_message = message.replace("current channel", self.voice_channel.mention) +        self.text_channels[0].send.assert_called_once_with(updated_message) +        self.text_channels[1].send.assert_called_once_with(updated_message) + +        mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) + +    async def test_silence_voice_sibling_channel(self): +        """Tests silencing a voice channel from the related text channel.""" +        with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: +            mock_voice_channels.get.return_value = self.text_channels[1].id + +            message = "This should show up as current channel." +            await self.cog.send_message(message, self.text_channels[1], self.voice_channel, alert_target=True) + +            updated_message = message.replace("current channel", self.voice_channel.mention) +            self.text_channels[1].send.assert_called_once_with(updated_message) + +            mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) +            self.bot.get_channel.assert_called_once_with(self.text_channels[1].id) diff --git a/tests/helpers.py b/tests/helpers.py index 496363ae3..529664e67 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient  from bot.bot import Bot  from tests._autospec import autospec  # noqa: F401 other modules import it via this module -  for logger in logging.Logger.manager.loggerDict.values():      # Set all loggers to CRITICAL by default to prevent screen clutter during testing @@ -320,7 +319,10 @@ channel_data = {  }  state = unittest.mock.MagicMock()  guild = unittest.mock.MagicMock() -channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + +channel_data["type"] = "VoiceChannel" +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data)  class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): @@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ -    spec_set = channel_instance +    spec_set = text_channel_instance + +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} +        super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + +        if 'mention' not in kwargs: +            self.mention = f"#{self.name}" + + +class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): +    """ +    A MagicMock subclass to mock VoiceChannel objects. + +    Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For +    more information, see the `MockGuild` docstring. +    """ +    spec_set = voice_channel_instance      def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} | 
