diff options
| -rw-r--r-- | bot/exts/moderation/silence.py | 27 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 91 | ||||
| -rw-r--r-- | tests/helpers.py | 25 | 
3 files changed, 126 insertions, 17 deletions
| diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 2d928182a..2aebee9d7 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -104,6 +104,20 @@ class Silence(commands.Cog):          self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log))          await self._reschedule() +    async def _get_related_text_channel(self, channel: VoiceChannel) -> Optional[TextChannel]: +        """Returns the text channel related to a voice channel.""" +        # TODO: Figure out a dynamic way of doing this +        channels = { +            "off-topic": Channels.voice_chat, +            "code/help 1": Channels.code_help_voice, +            "code/help 2": Channels.code_help_voice, +            "admin": Channels.admins_voice, +            "staff": Channels.staff_voice +        } +        for name in channels.keys(): +            if name in channel.name.lower(): +                return self.bot.get_channel(channels[name]) +      async def send_message(self, message: str, source_channel: TextChannel,                             target_channel: Union[TextChannel, VoiceChannel],                             alert_target: bool = False, duration: HushDurationConverter = 0) -> None: @@ -116,18 +130,7 @@ class Silence(commands.Cog):          voice_chat = None          if isinstance(target_channel, VoiceChannel):              # Send to relevant channel -            # TODO: Figure out a dynamic way of doing this -            channels = { -                "offtopic": Channels.voice_chat, -                "code/help 1": Channels.code_help_voice, -                "code/help 2": Channels.code_help_voice, -                "admin": Channels.admins_voice, -                "staff": Channels.staff_voice -            } -            for name in channels.keys(): -                if name in target_channel.name.lower(): -                    voice_chat = self.bot.get_channel(channels[name]) -                    break +            voice_chat = await self._get_related_text_channel(target_channel)          if alert_target and source_channel != target_channel:              if isinstance(target_channel, VoiceChannel): diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index bac933115..577725071 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -7,9 +7,9 @@ from unittest.mock import Mock  from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Guild, Roles +from bot.constants import Channels, Emojis, Guild, Roles  from bot.exts.moderation import silence -from tests.helpers import MockBot, MockContext, MockTextChannel, autospec +from tests.helpers import MockBot, MockContext, MockTextChannel, MockVoiceChannel, autospec  redis_session = None  redis_loop = asyncio.get_event_loop() @@ -168,6 +168,93 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):          role_check.assert_called_once_with(*(1, 2, 3))          role_check.return_value.predicate.assert_awaited_once_with(ctx) +    @mock.patch.object(silence.Silence, "_get_related_text_channel") +    async def test_send_message(self, mock_get_related_text_channel): +        """Test the send function reports to the correct channels.""" +        text_channel_1 = MockTextChannel() +        text_channel_2 = MockTextChannel() + +        voice_channel = MockVoiceChannel() +        voice_channel.name = "General/Offtopic" +        voice_channel.mention = f"#{voice_channel.name}" + +        mock_get_related_text_channel.return_value = text_channel_2 + +        def reset(): +            text_channel_1.reset_mock() +            text_channel_2.reset_mock() +            voice_channel.reset_mock() + +        with self.subTest("Basic One Channel Test"): +            await self.cog.send_message("Text basic message.", text_channel_1, text_channel_2, False) +            text_channel_1.send.assert_called_once_with("Text basic message.") +            text_channel_2.send.assert_not_called() + +        reset() +        with self.subTest("Basic Two Channel Test"): +            await self.cog.send_message("Text basic message.", text_channel_1, text_channel_2, True) +            text_channel_1.send.assert_called_once_with("Text basic message.") +            text_channel_2.send.assert_called_once_with("Text basic message.") + +        reset() +        with self.subTest("Replacement One Channel Test"): +            await self.cog.send_message("The following should be replaced: current", +                                        text_channel_1, text_channel_2, False) +            text_channel_1.send.assert_called_once_with(f"The following should be replaced: {text_channel_1.mention}") +            text_channel_2.send.assert_not_called() + +        reset() +        with self.subTest("Replacement Two Channel Test"): +            await self.cog.send_message("The following should be replaced: current", +                                        text_channel_1, text_channel_2, True) +            text_channel_1.send.assert_called_once_with(f"The following should be replaced: {text_channel_1.mention}") +            text_channel_2.send.assert_called_once_with("The following should be replaced: current") + +        reset() +        with self.subTest("Replace Duration"): +            await self.cog.send_message(f"{Emojis.check_mark} The following should be replaced: {{duration}}", +                                        text_channel_1, text_channel_2, False) +            text_channel_1.send.assert_called_once_with(f"{Emojis.check_mark} The following should be replaced: 0") +            text_channel_2.send.assert_not_called() + +        reset() +        with self.subTest("Text and Voice"): +            await self.cog.send_message("This should show up just here", +                                        text_channel_1, voice_channel, False) +            text_channel_1.send.assert_called_once_with("This should show up just here") + +        reset() +        with self.subTest("Text and Voice"): +            await self.cog.send_message("This should show up as current", +                                        text_channel_1, voice_channel, True) +            text_channel_1.send.assert_called_once_with(f"This should show up as {voice_channel.mention}") +            text_channel_2.send.assert_called_once_with(f"This should show up as {voice_channel.mention}") + +        reset() +        with self.subTest("Text and Voice Same Invocation"): +            await self.cog.send_message("This should show up as current", +                                        text_channel_2, voice_channel, True) +            text_channel_2.send.assert_called_once_with(f"This should show up as {voice_channel.mention}") + +    async def test_get_related_text_channel(self): +        voice_channel = MockVoiceChannel() + +        tests = ( +            ("Off-Topic/General", Channels.voice_chat), +            ("code/help 1", Channels.code_help_voice), +            ("Staff", Channels.staff_voice), +            ("ADMIN", Channels.admins_voice), +            ("not in the channel list", None) +        ) + +        with mock.patch.object(self.cog.bot, "get_channel", lambda x: x): +            for (name, channel_id) in tests: +                voice_channel.name = name +                voice_channel.id = channel_id + +                result_id = await self.cog._get_related_text_channel(voice_channel) +                self.assertEqual(result_id, channel_id) +  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class RescheduleTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 870f66197..5628ca31f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient  from bot.bot import Bot  from tests._autospec import autospec  # noqa: F401 other modules import it via this module -  for logger in logging.Logger.manager.loggerDict.values():      # Set all loggers to CRITICAL by default to prevent screen clutter during testing @@ -320,7 +319,10 @@ channel_data = {  }  state = unittest.mock.MagicMock()  guild = unittest.mock.MagicMock() -channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + +channel_data["type"] = "VoiceChannel" +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data)  class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): @@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ -    spec_set = channel_instance +    spec_set = text_channel_instance + +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} +        super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + +        if 'mention' not in kwargs: +            self.mention = f"#{self.name}" + + +class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): +    """ +    A MagicMock subclass to mock VoiceChannel objects. + +    Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For +    more information, see the `MockGuild` docstring. +    """ +    spec_set = voice_channel_instance      def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} | 
