diff options
Diffstat (limited to '')
| -rw-r--r-- | tests/_autospec.py | 64 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 120 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 148 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 587 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 7 | ||||
| -rw-r--r-- | tests/helpers.py | 21 | 
6 files changed, 721 insertions, 226 deletions
| diff --git a/tests/_autospec.py b/tests/_autospec.py new file mode 100644 index 000000000..ee2fc1973 --- /dev/null +++ b/tests/_autospec.py @@ -0,0 +1,64 @@ +import contextlib +import functools +import unittest.mock +from typing import Callable + + [email protected](unittest.mock._patch.decoration_helper) +def _decoration_helper(self, patched, args, keywargs): +    """Skips adding patchings as args if their `dont_pass` attribute is True.""" +    # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added. +    extra_args = [] +    with contextlib.ExitStack() as exit_stack: +        for patching in patched.patchings: +            arg = exit_stack.enter_context(patching) +            if not getattr(patching, "dont_pass", False): +                # Only add the patching as an arg if dont_pass is False. +                if patching.attribute_name is not None: +                    keywargs.update(arg) +                elif patching.new is unittest.mock.DEFAULT: +                    extra_args.append(arg) + +        args += tuple(extra_args) +        yield args, keywargs + + [email protected](unittest.mock._patch.copy) +def _copy(self): +    """Copy the `dont_pass` attribute along with the standard copy operation.""" +    patcher_copy = _copy.original(self) +    patcher_copy.dont_pass = getattr(self, "dont_pass", False) +    return patcher_copy + + +# Monkey-patch the patcher class :) +_copy.original = unittest.mock._patch.copy +unittest.mock._patch.copy = _copy +unittest.mock._patch.decoration_helper = _decoration_helper + + +def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable: +    """ +    Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + +    If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object. +    """ +    # Caller's kwargs should take priority and overwrite the defaults. +    kwargs = dict(spec_set=True, autospec=True) +    kwargs.update(patch_kwargs) + +    # Import the target if it's a string. +    # This is to support both object and string targets like patch.multiple. +    if type(target) is str: +        target = unittest.mock._importer(target) + +    def decorator(func): +        for attribute in attributes: +            patcher = unittest.mock.patch.object(target, attribute, **kwargs) +            if not pass_mocks: +                # A custom attribute to keep track of which patchings should be skipped. +                patcher.dont_pass = True +            func = patcher(func) +        return func +    return decorator diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index c0a1da35c..9f380a15d 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,6 @@  import unittest -from unittest import mock -from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User +from bot.exts.backend.sync._syncers import UserSyncer, _Diff  from tests import helpers @@ -10,7 +9,7 @@ def fake_user(**kwargs):      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man")      kwargs.setdefault("discriminator", 1337) -    kwargs.setdefault("roles", (666,)) +    kwargs.setdefault("roles", [666])      kwargs.setdefault("in_guild", True)      return kwargs @@ -40,22 +39,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          return guild +    @staticmethod +    def get_mock_member(member: dict): +        member = member.copy() +        del member["in_guild"] +        mock_member = helpers.MockMember(**member) +        mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] +        return mock_member +      async def test_empty_diff_for_no_users(self):          """When no users are given, an empty diff should be returned.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [] +        }          guild = self.get_guild()          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_identical_users(self):          """No differences should be found if the users in the guild and DB are identical.""" -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.return_value = self.get_mock_member(fake_user())          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -63,59 +82,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          """Only updated users should be added to the 'updated' set of the diff."""          updated_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(id=99, name="old"), fake_user()] +        }          guild = self.get_guild(updated_user, fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(updated_user), +            self.get_mock_member(fake_user()) +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**updated_user)}, None) +        expected_diff = ([], [{"id": 99, "name": "new"}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_users(self): -        """Only new users should be added to the 'created' set of the diff.""" +        """Only new users should be added to the 'created' list of the diff."""          new_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user(), new_user) - +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            self.get_mock_member(new_user) +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, set(), None) +        expected_diff = ([new_user], [], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        leaving_user = fake_user(id=63, in_guild=False) - -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=63)] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            None +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**leaving_user)}, None) +        expected_diff = ([], [{"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_updated_and_leaving_users(self):          """When users are added, updated, and removed, all of them are returned properly."""          new_user = fake_user(id=99, name="new") +          updated_user = fake_user(id=55, name="updated") -        leaving_user = fake_user(id=63, in_guild=False) -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=55), fake_user(id=63)] +        }          guild = self.get_guild(fake_user(), new_user, updated_user) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            self.get_mock_member(updated_user), +            None +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) +        expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_db_users_not_in_guild(self): -        """When the DB knows a user the guild doesn't, no difference is found.""" -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        """When the DB knows a user, but the guild doesn't, no difference is found.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=63, in_guild=False)] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            None +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -131,13 +193,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          """Only POST requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] -        user_tuples = {_User(**user) for user in users} -        diff = _Diff(user_tuples, set(), None) +        diff = _Diff(users, [], None)          await self.syncer._sync(diff) -        calls = [mock.call("bot/users", json=user) for user in users] -        self.bot.api_client.post.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.post.call_count, len(users)) +        self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() @@ -146,13 +205,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          """Only PUT requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] -        user_tuples = {_User(**user) for user in users} -        diff = _Diff(set(), user_tuples, None) +        diff = _Diff([], users, None)          await self.syncer._sync(diff) -        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] -        self.bot.api_client.put.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.put.call_count, len(users)) +        self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index be1b649e1..bf557a484 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,7 +1,8 @@  import textwrap  import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch +from bot.constants import Event  from bot.exts.moderation.infraction.infractions import Infractions  from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.apply_infraction.assert_awaited_once_with(              self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value          ) + + +@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) +class VoiceBanTests(unittest.IsolatedAsyncioTestCase): +    """Tests for voice ban related functions and commands.""" + +    def setUp(self): +        self.bot = MockBot() +        self.mod = MockMember(top_role=10) +        self.user = MockMember(top_role=1, roles=[MockRole(id=123456)]) +        self.guild = MockGuild() +        self.ctx = MockContext(bot=self.bot, author=self.mod) +        self.cog = Infractions(self.bot) + +    async def test_permanent_voice_ban(self): +        """Should call voice ban applying function without expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar") + +    async def test_temporary_voice_ban(self): +        """Should call voice ban applying function with expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + +    async def test_voice_unban(self): +        """Should call infraction pardoning function.""" +        self.cog.pardon_infraction = AsyncMock() +        self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) +        self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): +        """Should return early when user already have Voice Ban infraction.""" +        get_active_infraction.return_value = {"foo": "bar"} +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") +        post_infraction_mock.assert_not_awaited() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): +        """Should return early when posting infraction fails.""" +        self.cog.mod_log.ignore = MagicMock() +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        post_infraction_mock.assert_awaited_once() +        self.cog.mod_log.ignore.assert_not_called() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): +        """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" +        get_active_infraction.return_value = None +        # We don't want that this continue yet +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) +        post_infraction_mock.assert_awaited_once_with( +            self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 +        ) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar") +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): +        """Should truncate reason for voice ban.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) +        self.user.remove_roles.assert_called_once_with( +            self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...") +        ) +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    async def test_voice_unban_user_not_found(self): +        """Should include info to return dict when user was not found from guild.""" +        self.guild.get_member.return_value = None +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, {"Info": "User was not found in the guild."}) + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = True +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "Sent" +        }) +        notify_pardon_mock.assert_awaited_once() + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = False +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "**Failed**" +        }) +        notify_pardon_mock.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 3c2d52ae0..104293d8e 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,23 +1,49 @@ +import asyncio  import unittest +from datetime import datetime, timezone  from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock +from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Emojis, Guild, Roles -from bot.exts.moderation.silence import Silence, SilenceNotifier -from tests.helpers import MockBot, MockContext, MockTextChannel +from bot.constants import Channels, Guild, Roles +from bot.exts.moderation import silence +from tests.helpers import MockBot, MockContext, MockTextChannel, autospec + +redis_session = None +redis_loop = asyncio.get_event_loop() + + +def setUpModule():  # noqa: N802 +    """Create and connect to the fakeredis session.""" +    global redis_session +    redis_session = RedisSession(use_fakeredis=True) +    redis_loop.run_until_complete(redis_session.connect()) + + +def tearDownModule():  # noqa: N802 +    """Close the fakeredis session.""" +    if redis_session: +        redis_loop.run_until_complete(redis_session.close()) + + +# Have to subclass it because builtins can't be patched. +class PatchedDatetime(datetime): +    """A datetime object with a mocked now() function.""" + +    now = mock.create_autospec(datetime, "now")  class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None:          self.alert_channel = MockTextChannel() -        self.notifier = SilenceNotifier(self.alert_channel) +        self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock()          self.notifier.start = self.notifier_start_mock = Mock()      def test_add_channel_adds_channel(self): -        """Channel in FirstHash with current loop is added to internal set.""" +        """Channel is added to `_silenced_channels` with the current loop."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.add_channel(channel) @@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):          self.notifier_start_mock.assert_not_called()      def test_remove_channel_removes_channel(self): -        """Channel in FirstHash is removed from `_silenced_channels`.""" +        """Channel is removed from `_silenced_channels`."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.remove_channel(channel) @@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):              with self.subTest(current_loop=current_loop):                  with mock.patch.object(self.notifier, "_current_loop", new=current_loop):                      await self.notifier._notifier() -                self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") +                self.alert_channel.send.assert_called_once_with( +                    f"<@&{Roles.moderators}> currently silenced channels: " +                )              self.alert_channel.send.reset_mock()      async def test_notifier_skips_alert(self): @@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):                      self.alert_channel.send.assert_not_called() -class SilenceTests(unittest.IsolatedAsyncioTestCase): +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the general functionality of the Silence cog.""" + +    @autospec(silence, "Scheduler", pass_mocks=False)      def setUp(self) -> None:          self.bot = MockBot() -        self.cog = Silence(self.bot) -        self.ctx = MockContext() -        self.cog._verified_role = None -        # Set event so command callbacks can continue. -        self.cog._get_instance_vars_event.set() +        self.cog = silence.Silence(self.bot) -    async def test_instance_vars_got_guild(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog._get_instance_vars() -        self.bot.wait_until_guild_available.assert_called_once() +        await self.cog._async_init() +        self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id) -    async def test_instance_vars_got_role(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_role(self):          """Got `Roles.verified` role from guild.""" -        await self.cog._get_instance_vars()          guild = self.bot.get_guild() -        guild.get_role.assert_called_once_with(Roles.verified) +        guild.get_role.side_effect = lambda id_: Mock(id=id_) -    async def test_instance_vars_got_channels(self): +        await self.cog._async_init() +        self.assertEqual(self.cog._verified_role.id, Roles.verified) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_channels(self):          """Got channels from bot.""" -        await self.cog._get_instance_vars() -        self.bot.get_channel.called_once_with(Channels.mod_alerts) -        self.bot.get_channel.called_once_with(Channels.mod_log) +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) -    @mock.patch("bot.exts.moderation.silence.SilenceNotifier") -    async def test_instance_vars_got_notifier(self, notifier): +    @autospec(silence, "SilenceNotifier") +    async def test_async_init_got_notifier(self, notifier):          """Notifier was started with channel.""" -        mod_log = MockTextChannel() -        self.bot.get_channel.side_effect = (None, mod_log) -        await self.cog._get_instance_vars() -        notifier.assert_called_once_with(mod_log) -        self.bot.get_channel.side_effect = None - -    async def test_silence_sent_correct_discord_message(self): -        """Check if proper message was sent when called with duration in channel with previous state.""" +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) +        self.assertEqual(self.cog.notifier, notifier.return_value) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_rescheduled(self): +        """`_reschedule_` coroutine was awaited.""" +        self.cog._reschedule = mock.create_autospec(self.cog._reschedule) +        await self.cog._async_init() +        self.cog._reschedule.assert_awaited_once_with() + +    def test_cog_unload_cancelled_tasks(self): +        """The init task was cancelled.""" +        self.cog._init_task = asyncio.Future() +        self.cog.cog_unload() + +        # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. +        self.assertTrue(self.cog._init_task.cancelled()) + +    @autospec("discord.ext.commands", "has_any_role") +    @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) +    async def test_cog_check(self, role_check): +        """Role check was called with `MODERATION_ROLES`""" +        ctx = MockContext() +        role_check.return_value.predicate = mock.AsyncMock() + +        await self.cog.cog_check(ctx) +        role_check.assert_called_once_with(*(1, 2, 3)) +        role_check.return_value.predicate.assert_awaited_once_with(ctx) + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class RescheduleTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the rescheduling of cached unsilences.""" + +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self): +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) + +        with mock.patch.object(self.cog, "_reschedule", autospec=True): +            asyncio.run(self.cog._async_init())  # Populate instance attributes. + +    async def test_skipped_missing_channel(self): +        """Did nothing because the channel couldn't be retrieved.""" +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] +        self.bot.get_channel.return_value = None + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_added_permanent_to_notifier(self): +        """Permanently silenced channels were added to the notifier.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_any_call(channels[0]) +        self.cog.notifier.add_channel.assert_any_call(channels[1]) + +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_unsilenced_expired(self): +        """Unsilenced expired silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] + +        await self.cog._reschedule() + +        self.cog._unsilence_wrapper.assert_any_call(channels[0]) +        self.cog._unsilence_wrapper.assert_any_call(channels[1]) + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    @mock.patch.object(silence, "datetime", new=PatchedDatetime) +    async def test_rescheduled_active(self): +        """Rescheduled active silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] +        silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc) + +        self.cog._unsilence_wrapper = mock.MagicMock() +        unsilence_return = self.cog._unsilence_wrapper.return_value + +        await self.cog._reschedule() + +        # Yuck. +        calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)] +        self.cog.scheduler.schedule_later.assert_has_calls(calls) + +        unsilence_calls = [mock.call(channel) for channel in channels] +        self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls) + +        self.cog.notifier.add_channel.assert_not_called() + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the silence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        # Avoid unawaited coroutine warnings. +        self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command."""          test_cases = ( -            (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), -            (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), -            (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), +            (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), +            (None, silence.MSG_SILENCE_PERMANENT, True,), +            (5, silence.MSG_SILENCE_FAIL, False,),          ) -        for duration, result_message, _silence_patch_return in test_cases: -            with self.subTest( -                silence_duration=duration, -                result_message=result_message, -                starting_unsilenced_state=_silence_patch_return -            ): -                with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): -                    await self.cog.silence(self.cog, self.ctx, duration) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_unsilence_sent_correct_discord_message(self): -        """Check if proper message was sent when unsilencing channel.""" -        test_cases = ( -            (True, f"{Emojis.check_mark} unsilenced current channel."), -            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        for duration, message, was_silenced in test_cases: +            ctx = MockContext() +            with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): +                with self.subTest(was_silenced=was_silenced, message=message, duration=duration): +                    await self.cog.silence.callback(self.cog, ctx, duration) +                    ctx.send.assert_called_once_with(message) + +    async def test_skipped_already_silenced(self): +        """Permissions were not set and `False` was returned for an already silenced channel.""" +        subtests = ( +            (False, PermissionOverwrite(send_messages=False, add_reactions=False)), +            (True, PermissionOverwrite(send_messages=True, add_reactions=True)), +            (True, PermissionOverwrite(send_messages=False, add_reactions=False)),          ) -        for _unsilence_patch_return, result_message in test_cases: -            with self.subTest( -                starting_silenced_state=_unsilence_patch_return, -                result_message=result_message -            ): -                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): -                    await self.cog.unsilence(self.cog, self.ctx) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_silence_private_for_false(self): -        """Permissions are not set and `False` is returned in an already silenced channel.""" -        perm_overwrite = Mock(send_messages=False) -        channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - -        self.assertFalse(await self.cog._silence(channel, True, None)) -        channel.set_permissions.assert_not_called() -    async def test_silence_private_silenced_channel(self): -        """Channel had `send_message` permissions revoked.""" -        channel = MockTextChannel() -        self.assertTrue(await self.cog._silence(channel, False, None)) -        channel.set_permissions.assert_called_once() -        self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) +        for contains, overwrite in subtests: +            with self.subTest(contains=contains, overwrite=overwrite): +                self.cog.scheduler.__contains__.return_value = contains +                channel = MockTextChannel() +                channel.overwrites_for.return_value = overwrite + +                self.assertFalse(await self.cog._set_silence_overwrites(channel)) +                channel.set_permissions.assert_not_called() + +    async def test_silenced_channel(self): +        """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) +        self.assertFalse(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite +        ) -    async def test_silence_private_preserves_permissions(self): -        """Previous permissions were preserved when channel was silenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite() -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._silence(channel, False, None) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    async def test_silence_private_notifier(self): -        """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" -        channel = MockTextChannel() -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=True): -                await self.cog._silence(channel, True, None) -                self.cog.notifier.add_channel.assert_called_once() - -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=False): -                await self.cog._silence(channel, False, None) -                self.cog.notifier.add_channel.assert_not_called() - -    async def test_silence_private_added_muted_channel(self): -        """Channel was added to `muted_channels` on silence.""" +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed.""" +        prev_overwrite_dict = dict(self.overwrite) +        await self.cog._set_silence_overwrites(self.channel) +        new_overwrite_dict = dict(self.overwrite) + +        # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. +        del prev_overwrite_dict['send_messages'] +        del prev_overwrite_dict['add_reactions'] +        del new_overwrite_dict['send_messages'] +        del new_overwrite_dict['add_reactions'] + +        self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_temp_not_added_to_notifier(self): +        """Channel was not added to notifier if a duration was set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_indefinite_added_to_notifier(self): +        """Channel was added to notifier if a duration was not set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), None) +            self.cog.notifier.add_channel.assert_called_once() + +    async def test_silenced_not_added_to_notifier(self): +        """Channel was not added to the notifier if it was already silenced.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_cached_previous_overwrites(self): +        """Channel's previous overwrites were cached.""" +        overwrite_json = '{"send_messages": true, "add_reactions": false}' +        await self.cog._set_silence_overwrites(self.channel) +        self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + +    @autospec(silence, "datetime") +    async def test_cached_unsilence_time(self, datetime_mock): +        """The UTC POSIX timestamp for the unsilence was cached.""" +        now_timestamp = 100 +        duration = 15 +        timestamp = now_timestamp + duration * 60 +        datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) + +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, duration) + +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) +        datetime_mock.now.assert_called_once_with(tz=timezone.utc)  # Ensure it's using an aware dt. + +    async def test_cached_indefinite_time(self): +        """A value of -1 was cached for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) + +    async def test_scheduled_task(self): +        """An unsilence task was scheduled.""" +        ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + +        await self.cog.silence.callback(self.cog, ctx, 5) + +        args = (300, ctx.channel.id, ctx.invoke.return_value) +        self.cog.scheduler.schedule_later.assert_called_once_with(*args) +        ctx.invoke.assert_called_once_with(self.cog.unsilence) + +    async def test_permanent_not_scheduled(self): +        """A task was not scheduled for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.scheduler.schedule_later.assert_not_called() + + +@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) +class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the unsilence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.cog.scheduler.__contains__.return_value = True +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command.""" +        unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) +        test_cases = ( +            (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), +        ) +        for was_unsilenced, message, overwrite in test_cases: +            ctx = MockContext() +            with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): +                with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): +                    ctx.channel.overwrites_for.return_value = overwrite +                    await self.cog.unsilence.callback(self.cog, ctx) +                    ctx.channel.send.assert_called_once_with(message) + +    async def test_skipped_already_unsilenced(self): +        """Permissions were not set and `False` was returned for an already unsilenced channel.""" +        self.cog.scheduler.__contains__.return_value = False +        self.cog.previous_overwrites.get.return_value = None          channel = MockTextChannel() -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._silence(channel, False, None) -        muted_channels.add.assert_called_once_with(channel) -    async def test_unsilence_private_for_false(self): -        """Permissions are not set and `False` is returned in an unsilenced channel.""" -        channel = Mock()          self.assertFalse(await self.cog._unsilence(channel))          channel.set_permissions.assert_not_called() -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_unsilenced_channel(self, _): -        """Channel had `send_message` permissions restored""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        self.assertTrue(await self.cog._unsilence(channel)) -        channel.set_permissions.assert_called_once() -        self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_notifier(self, notifier): -        """Channel was removed from `notifier` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        await self.cog._unsilence(channel) -        notifier.remove_channel.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_muted_channel(self, _): -        """Channel was removed from `muted_channels` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._unsilence(channel) -        muted_channels.discard.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_preserves_permissions(self, _): -        """Previous permissions were preserved when channel was unsilenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite(send_messages=False) -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._unsilence(channel) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    @mock.patch.object(Silence, "_mod_alerts_channel", create=True) -    def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): -        """Task for sending an alert was created with present `muted_channels`.""" -        with mock.patch.object(self.cog, "muted_channels"): -            self.cog.cog_unload() -            alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") -            asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    def test_cog_unload_skips_task_start(self, asyncio_mock): -        """No task created with no channels.""" -        self.cog.cog_unload() -        asyncio_mock.create_task.assert_not_called() +    async def test_restored_overwrites(self): +        """Channel's `send_message` and `add_reactions` overwrites were restored.""" +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) -    @mock.patch("discord.ext.commands.has_any_role") -    @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) -    async def test_cog_check(self, role_check): -        """Role check is called with `MODERATION_ROLES`""" -        role_check.return_value.predicate = mock.AsyncMock() -        await self.cog.cog_check(self.ctx) -        role_check.assert_called_once_with(*(1, 2, 3)) -        role_check.return_value.predicate.assert_awaited_once_with(self.ctx) +        # Recall that these values are determined by the fixture. +        self.assertTrue(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) + +    async def test_cache_miss_used_default_overwrites(self): +        """Both overwrites were set to None due previous values not being found in the cache.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) + +        self.assertIsNone(self.overwrite.send_messages) +        self.assertIsNone(self.overwrite.add_reactions) + +    async def test_cache_miss_sent_mod_alert(self): +        """A message was sent to the mod alerts channel.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.cog._mod_alerts_channel.send.assert_awaited_once() + +    async def test_removed_notifier(self): +        """Channel was removed from `notifier`.""" +        await self.cog._unsilence(self.channel) +        self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + +    async def test_deleted_cached_overwrite(self): +        """Channel was deleted from the overwrites cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + +    async def test_deleted_cached_time(self): +        """Channel was deleted from the timestamp cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + +    async def test_cancelled_task(self): +        """The scheduled unsilence task should be cancelled.""" +        await self.cog._unsilence(self.channel) +        self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed, including cache misses.""" +        for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): +            with self.subTest(overwrite_json=overwrite_json): +                self.cog.previous_overwrites.get.return_value = overwrite_json + +                prev_overwrite_dict = dict(self.overwrite) +                await self.cog._unsilence(self.channel) +                new_overwrite_dict = dict(self.overwrite) + +                # Remove these keys because they were modified by the unsilence. +                del prev_overwrite_dict['send_messages'] +                del prev_overwrite_dict['add_reactions'] +                del new_overwrite_dict['send_messages'] +                del new_overwrite_dict['add_reactions'] + +                self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 6601fad2c..9a42d0610 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -52,6 +52,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),              ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),              ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), +            ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'), +            ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', +             'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'), +            ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', +             'print("How\'s it going?")', 'code block preceded by inline code'), +            ('`print("Hello world!")`\ntext\n`print("Hello world!")`', +             'print("Hello world!")', 'one inline code block of two')          )          for case, expected, testname in cases:              with self.subTest(msg=f'Extract code from {testname}.'): diff --git a/tests/helpers.py b/tests/helpers.py index e47fdf28f..870f66197 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools  import logging  import unittest.mock  from asyncio import AbstractEventLoop -from typing import Callable, Iterable, Optional +from typing import Iterable, Optional  import discord  from aiohttp import ClientSession @@ -14,6 +14,7 @@ from discord.ext.commands import Context  from bot.api import APIClient  from bot.async_stats import AsyncStatsClient  from bot.bot import Bot +from tests._autospec import autospec  # noqa: F401 other modules import it via this module  for logger in logging.Logger.manager.loggerDict.values(): @@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> Callable: -    """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" -    # Caller's kwargs should take priority and overwrite the defaults. -    kwargs = {'spec_set': True, 'autospec': True, **kwargs} - -    # Import the target if it's a string. -    # This is to support both object and string targets like patch.multiple. -    if type(target) is str: -        target = unittest.mock._importer(target) - -    def decorator(func): -        for attribute in attributes: -            patcher = unittest.mock.patch.object(target, attribute, **kwargs) -            func = patcher(func) -        return func -    return decorator - -  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. | 
