diff options
Diffstat (limited to '')
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 136 | ||||
| -rw-r--r-- | tests/helpers.py | 14 | 
2 files changed, 83 insertions, 67 deletions
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index a7f239d7f..86d396afd 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -37,15 +37,13 @@ class SilenceTest(RedisTestCase):          self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id))          self.cog = silence.Silence(self.bot) -    @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def asyncSetUp(self) -> None:          await super().asyncSetUp()          await self.cog.cog_load()  # Populate instance attributes. -class SilenceNotifierTests(SilenceTest): +class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None: -        super().setUp()          self.alert_channel = MockTextChannel()          self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock() @@ -54,32 +52,36 @@ class SilenceNotifierTests(SilenceTest):      def test_add_channel_adds_channel(self):          """Channel is added to `_silenced_channels` with the current loop."""          channel = Mock() -        with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: -            self.notifier.add_channel(channel) -        silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop) +        self.notifier.add_channel(channel) +        self.assertDictEqual(self.notifier._silenced_channels, {channel: self.notifier._current_loop}) + +    def test_add_channel_loop_called_correctly(self): +        """Loop is called only in correct scenarios.""" -    def test_add_channel_starts_loop(self): -        """Loop is started if `_silenced_channels` was empty.""" +        # Loop is started if `_silenced_channels` was empty.          self.notifier.add_channel(Mock())          self.notifier_start_mock.assert_called_once() -    def test_add_channel_skips_start_with_channels(self): -        """Loop start is not called when `_silenced_channels` is not empty.""" -        with mock.patch.object(self.notifier, "_silenced_channels"): -            self.notifier.add_channel(Mock()) +        self.notifier_start_mock.reset_mock() + +        # Loop start is not called when `_silenced_channels` is not empty. +        self.notifier.add_channel(Mock())          self.notifier_start_mock.assert_not_called()      def test_remove_channel_removes_channel(self):          """Channel is removed from `_silenced_channels`."""          channel = Mock() -        with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels: -            self.notifier.remove_channel(channel) -        silenced_channels.__delitem__.assert_called_with(channel) +        self.notifier.add_channel(channel) +        self.notifier.remove_channel(channel) +        self.assertDictEqual(self.notifier._silenced_channels, {})      def test_remove_channel_stops_loop(self):          """Notifier loop is stopped if `_silenced_channels` is empty after remove.""" -        with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False): -            self.notifier.remove_channel(Mock()) +        channel = Mock() +        self.notifier.add_channel(channel) +        self.notifier_stop_mock.assert_not_called() + +        self.notifier.remove_channel(channel)          self.notifier_stop_mock.assert_called_once()      def test_remove_channel_skips_stop_with_channels(self): @@ -111,33 +113,28 @@ class SilenceNotifierTests(SilenceTest):                  self.alert_channel.send.assert_not_called() -@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class SilenceCogTests(SilenceTest):      """Tests for the general functionality of the Silence cog.""" -    @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_guild(self):          """Bot got guild after it became available."""          self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id) -    @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def test_cog_load_got_channels(self):          """Got channels from bot.""" -        await self.cog.cog_load()          self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) -    @autospec(silence, "SilenceNotifier") -    async def test_cog_load_got_notifier(self, notifier): +    async def test_cog_load_got_notifier(self):          """Notifier was started with channel.""" -        await self.cog.cog_load() +        with mock.patch.object(silence, "SilenceNotifier") as notifier: +            await self.cog.cog_load()          notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))          self.assertEqual(self.cog.notifier, notifier.return_value) -    @autospec(silence, "SilenceNotifier", pass_mocks=False)      async def testcog_load_rescheduled(self):          """`_reschedule_` coroutine was awaited.""" -        self.cog._reschedule = mock.create_autospec(self.cog._reschedule) +        self.cog._reschedule = AsyncMock()          await self.cog.cog_load()          self.cog._reschedule.assert_awaited_once_with() @@ -242,7 +239,7 @@ class SilenceCogTests(SilenceTest):              self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) -class SilenceArgumentParserTests(SilenceTest): +class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):      """Tests for the silence argument parser utility function."""      @autospec(silence.Silence, "send_message", pass_mocks=False) @@ -250,6 +247,9 @@ class SilenceArgumentParserTests(SilenceTest):      @autospec(silence.Silence, "parse_silence_args")      async def test_command(self, parser_mock):          """Test that the command passes in the correct arguments for different calls.""" +        bot = MockBot() +        cog = silence.Silence(bot) +          test_cases = (              (),              (15, ), @@ -262,7 +262,7 @@ class SilenceArgumentParserTests(SilenceTest):          for case in test_cases:              with self.subTest("Test command converters", args=case): -                await self.cog.silence.callback(self.cog, ctx, *case) +                await cog.silence.callback(cog, ctx, *case)                  try:                      first_arg = case[0] @@ -281,7 +281,7 @@ class SilenceArgumentParserTests(SilenceTest):      async def test_no_arguments(self):          """Test the parser when no arguments are passed to the command."""          ctx = MockContext() -        channel, duration = self.cog.parse_silence_args(ctx, None, 10) +        channel, duration = silence.Silence.parse_silence_args(ctx, None, 10)          self.assertEqual(ctx.channel, channel)          self.assertEqual(10, duration) @@ -289,7 +289,7 @@ class SilenceArgumentParserTests(SilenceTest):      async def test_channel_only(self):          """Test the parser when just the channel argument is passed."""          expected_channel = MockTextChannel() -        actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 10) +        actual_channel, duration = silence.Silence.parse_silence_args(MockContext(), expected_channel, 10)          self.assertEqual(expected_channel, actual_channel)          self.assertEqual(10, duration) @@ -297,7 +297,7 @@ class SilenceArgumentParserTests(SilenceTest):      async def test_duration_only(self):          """Test the parser when just the duration argument is passed."""          ctx = MockContext() -        channel, duration = self.cog.parse_silence_args(ctx, 15, 10) +        channel, duration = silence.Silence.parse_silence_args(ctx, 15, 10)          self.assertEqual(ctx.channel, channel)          self.assertEqual(15, duration) @@ -305,13 +305,12 @@ class SilenceArgumentParserTests(SilenceTest):      async def test_all_args(self):          """Test the parser when both channel and duration are passed."""          expected_channel = MockTextChannel() -        actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 15) +        actual_channel, duration = silence.Silence.parse_silence_args(MockContext(), expected_channel, 15)          self.assertEqual(expected_channel, actual_channel)          self.assertEqual(15, duration) -@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class RescheduleTests(RedisTestCase):      """Tests for the rescheduling of cached unsilences.""" @@ -328,7 +327,7 @@ class RescheduleTests(RedisTestCase):      async def test_skipped_missing_channel(self):          """Did nothing because the channel couldn't be retrieved.""" -        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] +        await self.cog.unsilence_timestamps.set(123, -1)          self.bot.get_channel.return_value = None          await self.cog._reschedule() @@ -341,8 +340,8 @@ class RescheduleTests(RedisTestCase):          """Permanently silenced channels were added to the notifier."""          channels = [MockTextChannel(id=123), MockTextChannel(id=456)]          self.bot.get_channel.side_effect = channels -        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] - +        await self.cog.unsilence_timestamps.set(123, -1) +        await self.cog.unsilence_timestamps.set(456, -1)          await self.cog._reschedule()          self.cog.notifier.add_channel.assert_any_call(channels[0]) @@ -355,7 +354,8 @@ class RescheduleTests(RedisTestCase):          """Unsilenced expired silences."""          channels = [MockTextChannel(id=123), MockTextChannel(id=456)]          self.bot.get_channel.side_effect = channels -        self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] +        await self.cog.unsilence_timestamps.set(123, 100) +        await self.cog.unsilence_timestamps.set(456, 200)          await self.cog._reschedule() @@ -370,7 +370,8 @@ class RescheduleTests(RedisTestCase):          """Rescheduled active silences."""          channels = [MockTextChannel(id=123), MockTextChannel(id=456)]          self.bot.get_channel.side_effect = channels -        self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] +        await self.cog.unsilence_timestamps.set(123, 2000) +        await self.cog.unsilence_timestamps.set(456, 3000)          silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=UTC)          self.cog._unsilence_wrapper = mock.MagicMock() @@ -398,7 +399,6 @@ def voice_sync_helper(function):      return inner -@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class SilenceTests(SilenceTest):      """Tests for the silence command and its related helper methods.""" @@ -596,19 +596,28 @@ class SilenceTests(SilenceTest):      async def test_temp_not_added_to_notifier(self):          """Channel was not added to notifier if a duration was set for the silence.""" -        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +        with ( +            mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True), +            mock.patch.object(self.cog.notifier, "add_channel") +        ):              await self.cog.silence.callback(self.cog, MockContext(), 15)              self.cog.notifier.add_channel.assert_not_called()      async def test_indefinite_added_to_notifier(self):          """Channel was added to notifier if a duration was not set for the silence.""" -        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +        with ( +            mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True), +            mock.patch.object(self.cog.notifier, "add_channel") +        ):              await self.cog.silence.callback(self.cog, MockContext(), None, None)              self.cog.notifier.add_channel.assert_called_once()      async def test_silenced_not_added_to_notifier(self):          """Channel was not added to the notifier if it was already silenced.""" -        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): +        with ( +            mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False), +            mock.patch.object(self.cog.notifier, "add_channel") +        ):              await self.cog.silence.callback(self.cog, MockContext(), 15)              self.cog.notifier.add_channel.assert_not_called() @@ -619,7 +628,7 @@ class SilenceTests(SilenceTest):              '"create_public_threads": false, "send_messages_in_threads": true}'          )          await self.cog._set_silence_overwrites(self.text_channel) -        self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json) +        self.assertEqual(await self.cog.previous_overwrites.get(self.text_channel.id), overwrite_json)      @autospec(silence, "datetime")      async def test_cached_unsilence_time(self, datetime_mock): @@ -632,14 +641,14 @@ class SilenceTests(SilenceTest):          ctx = MockContext(channel=self.text_channel)          await self.cog.silence.callback(self.cog, ctx, duration) -        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) +        self.assertEqual(await self.cog.unsilence_timestamps.get(ctx.channel.id), timestamp)          datetime_mock.now.assert_called_once_with(tz=UTC)  # Ensure it's using an aware dt.      async def test_cached_indefinite_time(self):          """A value of -1 was cached for a permanent silence."""          ctx = MockContext(channel=self.text_channel)          await self.cog.silence.callback(self.cog, ctx, None, None) -        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) +        self.assertEqual(await self.cog.unsilence_timestamps.get(ctx.channel.id), -1)      async def test_scheduled_task(self):          """An unsilence task was scheduled.""" @@ -665,7 +674,6 @@ class SilenceTests(SilenceTest):              unsilence.assert_awaited_once_with(ctx, ctx.channel, None) -@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)  class UnsilenceTests(SilenceTest):      """Tests for the unsilence command and its related helper methods.""" @@ -681,13 +689,6 @@ class UnsilenceTests(SilenceTest):          self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)          self.voice_channel.overwrites_for.return_value = self.voice_overwrite -    async def asyncSetUp(self) -> None: -        await super().asyncSetUp() -        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) -        self.cog.previous_overwrites = overwrites_cache - -        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' -      async def test_sent_correct_message(self):          """Appropriate failure/success message was sent by the command."""          unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) @@ -720,7 +721,6 @@ class UnsilenceTests(SilenceTest):      async def test_skipped_already_unsilenced(self):          """Permissions were not set and `False` was returned for an already unsilenced channel."""          self.cog.scheduler.__contains__.return_value = False -        self.cog.previous_overwrites.get.return_value = None          for channel in (MockVoiceChannel(), MockTextChannel()):              with self.subTest(channel=channel): @@ -729,6 +729,7 @@ class UnsilenceTests(SilenceTest):      async def test_restored_overwrites_text(self):          """Text channel's `send_message` and `add_reactions` overwrites were restored.""" +        await self.cog.previous_overwrites.set(self.text_channel.id, '{"send_messages": true, "add_reactions": false}')          await self.cog._unsilence(self.text_channel)          self.text_channel.set_permissions.assert_awaited_once_with(              self.cog._everyone_role, @@ -741,19 +742,18 @@ class UnsilenceTests(SilenceTest):      async def test_restored_overwrites_voice(self):          """Voice channel's `connect` and `speak` overwrites were restored.""" +        await self.cog.previous_overwrites.set(self.voice_channel.id, '{"connect": true, "speak": true}')          await self.cog._unsilence(self.voice_channel)          self.voice_channel.set_permissions.assert_awaited_once_with(              self.cog._verified_voice_role,              overwrite=self.voice_overwrite,          ) -        # Recall that these values are determined by the fixture.          self.assertTrue(self.voice_overwrite.connect)          self.assertTrue(self.voice_overwrite.speak)      async def test_cache_miss_used_default_overwrites_text(self):          """Text overwrites were set to None due previous values not being found in the cache.""" -        self.cog.previous_overwrites.get.return_value = None          await self.cog._unsilence(self.text_channel)          self.text_channel.set_permissions.assert_awaited_once_with( @@ -766,7 +766,6 @@ class UnsilenceTests(SilenceTest):      async def test_cache_miss_used_default_overwrites_voice(self):          """Voice overwrites were set to None due previous values not being found in the cache.""" -        self.cog.previous_overwrites.get.return_value = None          await self.cog._unsilence(self.voice_channel)          self.voice_channel.set_permissions.assert_awaited_once_with( @@ -779,30 +778,31 @@ class UnsilenceTests(SilenceTest):      async def test_cache_miss_sent_mod_alert_text(self):          """A message was sent to the mod alerts channel upon muting a text channel.""" -        self.cog.previous_overwrites.get.return_value = None          await self.cog._unsilence(self.text_channel)          self.cog._mod_alerts_channel.send.assert_awaited_once()      async def test_cache_miss_sent_mod_alert_voice(self):          """A message was sent to the mod alerts channel upon muting a voice channel.""" -        self.cog.previous_overwrites.get.return_value = None          await self.cog._unsilence(MockVoiceChannel())          self.cog._mod_alerts_channel.send.assert_awaited_once()      async def test_removed_notifier(self):          """Channel was removed from `notifier`.""" -        await self.cog._unsilence(self.text_channel) -        self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel) +        with mock.patch.object(silence.SilenceNotifier, "remove_channel"): +            await self.cog._unsilence(self.text_channel) +            self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel)      async def test_deleted_cached_overwrite(self):          """Channel was deleted from the overwrites cache.""" +        await self.cog.previous_overwrites.set(self.text_channel.id, '{"send_messages": true, "add_reactions": false}')          await self.cog._unsilence(self.text_channel) -        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.text_channel.id) +        self.assertEqual(await self.cog.previous_overwrites.get(self.text_channel.id), None)      async def test_deleted_cached_time(self):          """Channel was deleted from the timestamp cache.""" +        await self.cog.unsilence_timestamps.set(self.text_channel.id, 100)          await self.cog._unsilence(self.text_channel) -        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.text_channel.id) +        self.assertEqual(await self.cog.unsilence_timestamps.get(self.text_channel.id), None)      async def test_cancelled_task(self):          """The scheduled unsilence task should be cancelled.""" @@ -813,7 +813,10 @@ class UnsilenceTests(SilenceTest):          """Text channel's other unrelated overwrites were not changed, including cache misses."""          for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):              with self.subTest(overwrite_json=overwrite_json): -                self.cog.previous_overwrites.get.return_value = overwrite_json +                if overwrite_json is None: +                    await self.cog.previous_overwrites.delete(self.text_channel.id) +                else: +                    await self.cog.previous_overwrites.set(self.text_channel.id, overwrite_json)                  prev_overwrite_dict = dict(self.text_overwrite)                  await self.cog._unsilence(self.text_channel) @@ -831,7 +834,10 @@ class UnsilenceTests(SilenceTest):          """Voice channel's other unrelated overwrites were not changed, including cache misses."""          for overwrite_json in ('{"connect": true, "speak": true}', None):              with self.subTest(overwrite_json=overwrite_json): -                self.cog.previous_overwrites.get.return_value = overwrite_json +                if overwrite_json is None: +                    await self.cog.previous_overwrites.delete(self.voice_channel.id) +                else: +                    await self.cog.previous_overwrites.set(self.voice_channel.id, overwrite_json)                  prev_overwrite_dict = dict(self.voice_overwrite)                  await self.cog._unsilence(self.voice_channel) diff --git a/tests/helpers.py b/tests/helpers.py index c51a82a9d..1164828d6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -388,12 +388,17 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      spec_set = text_channel_instance      def __init__(self, **kwargs) -> None: -        default_kwargs = {"id": next(self.discord_id), "name": "channel", "guild": MockGuild()} +        default_kwargs = {"id": next(self.discord_id), "name": "channel"}          super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if "mention" not in kwargs:              self.mention = f"#{self.name}" +    @cached_property +    def guild(self) -> MockGuild: +        """Cached guild property.""" +        return MockGuild() +  class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      """ @@ -405,12 +410,17 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      spec_set = voice_channel_instance      def __init__(self, **kwargs) -> None: -        default_kwargs = {"id": next(self.discord_id), "name": "channel", "guild": MockGuild()} +        default_kwargs = {"id": next(self.discord_id), "name": "channel"}          super().__init__(**collections.ChainMap(kwargs, default_kwargs))          if "mention" not in kwargs:              self.mention = f"#{self.name}" +    @cached_property +    def guild(self) -> MockGuild: +        """Cached guild property.""" +        return MockGuild() +  # Create data for the DMChannel instance  state = unittest.mock.MagicMock()  |