diff options
| author | 2020-11-21 03:04:41 +0200 | |
|---|---|---|
| committer | 2020-11-21 03:04:41 +0200 | |
| commit | c3a927569da782c24299c8ae75df28ae6cd3f2ba (patch) | |
| tree | 51f72457a0e7c8d97286de12aa2d594206d698b2 /tests | |
| parent | Make `additional_info` non-optional. (diff) | |
| parent | Merge pull request #1287 from python-discord/help-channel-msg (diff) | |
Merge branch 'master' into superstar-fix
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/_autospec.py | 64 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_base.py | 29 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_cog.py | 38 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_roles.py | 26 | ||||
| -rw-r--r-- | tests/bot/exts/backend/sync/test_users.py | 147 | ||||
| -rw-r--r-- | tests/bot/exts/info/test_information.py | 175 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 148 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 587 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_snekbox.py | 25 | ||||
| -rw-r--r-- | tests/bot/patches/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/rules/test_discord_emojis.py | 29 | ||||
| -rw-r--r-- | tests/bot/utils/test_services.py | 39 | ||||
| -rw-r--r-- | tests/helpers.py | 21 | 
13 files changed, 884 insertions, 444 deletions
| diff --git a/tests/_autospec.py b/tests/_autospec.py new file mode 100644 index 000000000..ee2fc1973 --- /dev/null +++ b/tests/_autospec.py @@ -0,0 +1,64 @@ +import contextlib +import functools +import unittest.mock +from typing import Callable + + [email protected](unittest.mock._patch.decoration_helper) +def _decoration_helper(self, patched, args, keywargs): +    """Skips adding patchings as args if their `dont_pass` attribute is True.""" +    # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added. +    extra_args = [] +    with contextlib.ExitStack() as exit_stack: +        for patching in patched.patchings: +            arg = exit_stack.enter_context(patching) +            if not getattr(patching, "dont_pass", False): +                # Only add the patching as an arg if dont_pass is False. +                if patching.attribute_name is not None: +                    keywargs.update(arg) +                elif patching.new is unittest.mock.DEFAULT: +                    extra_args.append(arg) + +        args += tuple(extra_args) +        yield args, keywargs + + [email protected](unittest.mock._patch.copy) +def _copy(self): +    """Copy the `dont_pass` attribute along with the standard copy operation.""" +    patcher_copy = _copy.original(self) +    patcher_copy.dont_pass = getattr(self, "dont_pass", False) +    return patcher_copy + + +# Monkey-patch the patcher class :) +_copy.original = unittest.mock._patch.copy +unittest.mock._patch.copy = _copy +unittest.mock._patch.decoration_helper = _decoration_helper + + +def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable: +    """ +    Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + +    If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object. +    """ +    # Caller's kwargs should take priority and overwrite the defaults. +    kwargs = dict(spec_set=True, autospec=True) +    kwargs.update(patch_kwargs) + +    # Import the target if it's a string. +    # This is to support both object and string targets like patch.multiple. +    if type(target) is str: +        target = unittest.mock._importer(target) + +    def decorator(func): +        for attribute in attributes: +            patcher = unittest.mock.patch.object(target, attribute, **kwargs) +            if not pass_mocks: +                # A custom attribute to keep track of which patchings should be skipped. +                patcher.dont_pass = True +            func = patcher(func) +        return func +    return decorator diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 4953550f9..3ad9db9c3 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -15,28 +15,21 @@ class TestSyncer(Syncer):      _sync = mock.AsyncMock() -class SyncerBaseTests(unittest.TestCase): -    """Tests for the syncer base class.""" - -    def setUp(self): -        self.bot = helpers.MockBot() - -    def test_instantiation_fails_without_abstract_methods(self): -        """The class must have abstract methods implemented.""" -        with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): -            Syncer(self.bot) - -  class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for main function orchestrating the sync."""      def setUp(self): -        self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) -        self.syncer = TestSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot(user=helpers.MockMember(bot=True))) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop) +          self.guild = helpers.MockGuild() +        TestSyncer._get_diff.reset_mock(return_value=True, side_effect=True) +        TestSyncer._sync.reset_mock(return_value=True, side_effect=True) +          # Make sure `_get_diff` returns a MagicMock, not an AsyncMock -        self.syncer._get_diff.return_value = mock.MagicMock() +        TestSyncer._get_diff.return_value = mock.MagicMock()      async def test_sync_message_edited(self):          """The message should be edited if one was sent, even if the sync has an API error.""" @@ -48,11 +41,11 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):          for message, side_effect, should_edit in subtests:              with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): -                self.syncer._sync.side_effect = side_effect +                TestSyncer._sync.side_effect = side_effect                  ctx = helpers.MockContext()                  ctx.send.return_value = message -                await self.syncer.sync(self.guild, ctx) +                await TestSyncer.sync(self.guild, ctx)                  if should_edit:                      message.edit.assert_called_once() @@ -67,7 +60,7 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):          for ctx, message in subtests:              with self.subTest(ctx=ctx, message=message): -                await self.syncer.sync(self.guild, ctx) +                await TestSyncer.sync(self.guild, ctx)                  if ctx is not None:                      ctx.send.assert_called_once() diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 1b89564f2..22a07313e 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -29,24 +29,24 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):      def setUp(self):          self.bot = helpers.MockBot() -        self.role_syncer_patcher = mock.patch( +        role_syncer_patcher = mock.patch(              "bot.exts.backend.sync._syncers.RoleSyncer",              autospec=Syncer,              spec_set=True          ) -        self.user_syncer_patcher = mock.patch( +        user_syncer_patcher = mock.patch(              "bot.exts.backend.sync._syncers.UserSyncer",              autospec=Syncer,              spec_set=True          ) -        self.RoleSyncer = self.role_syncer_patcher.start() -        self.UserSyncer = self.user_syncer_patcher.start() -        self.cog = Sync(self.bot) +        self.RoleSyncer = role_syncer_patcher.start() +        self.UserSyncer = user_syncer_patcher.start() -    def tearDown(self): -        self.role_syncer_patcher.stop() -        self.user_syncer_patcher.stop() +        self.addCleanup(role_syncer_patcher.stop) +        self.addCleanup(user_syncer_patcher.stop) + +        self.cog = Sync(self.bot)      @staticmethod      def response_error(status: int) -> ResponseCodeError: @@ -73,8 +73,6 @@ class SyncCogTests(SyncCogTestCase):          Sync(self.bot) -        self.RoleSyncer.assert_called_once_with(self.bot) -        self.UserSyncer.assert_called_once_with(self.bot)          sync_guild.assert_called_once_with()          self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) @@ -83,8 +81,8 @@ class SyncCogTests(SyncCogTestCase):          for guild in (helpers.MockGuild(), None):              with self.subTest(guild=guild):                  self.bot.reset_mock() -                self.cog.role_syncer.reset_mock() -                self.cog.user_syncer.reset_mock() +                self.RoleSyncer.reset_mock() +                self.UserSyncer.reset_mock()                  self.bot.get_guild = mock.MagicMock(return_value=guild) @@ -94,11 +92,11 @@ class SyncCogTests(SyncCogTestCase):                  self.bot.get_guild.assert_called_once_with(constants.Guild.id)                  if guild is None: -                    self.cog.role_syncer.sync.assert_not_called() -                    self.cog.user_syncer.sync.assert_not_called() +                    self.RoleSyncer.sync.assert_not_called() +                    self.UserSyncer.sync.assert_not_called()                  else: -                    self.cog.role_syncer.sync.assert_called_once_with(guild) -                    self.cog.user_syncer.sync.assert_called_once_with(guild) +                    self.RoleSyncer.sync.assert_called_once_with(guild) +                    self.UserSyncer.sync.assert_called_once_with(guild)      async def patch_user_helper(self, side_effect: BaseException) -> None:          """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" @@ -392,16 +390,16 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):      async def test_sync_roles_command(self):          """sync() should be called on the RoleSyncer."""          ctx = helpers.MockContext() -        await self.cog.sync_roles_command.callback(self.cog, ctx) +        await self.cog.sync_roles_command(self.cog, ctx) -        self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) +        self.RoleSyncer.sync.assert_called_once_with(ctx.guild, ctx)      async def test_sync_users_command(self):          """sync() should be called on the UserSyncer."""          ctx = helpers.MockContext() -        await self.cog.sync_users_command.callback(self.cog, ctx) +        await self.cog.sync_users_command(self.cog, ctx) -        self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) +        self.UserSyncer.sync.assert_called_once_with(ctx.guild, ctx)      async def test_commands_require_admin(self):          """The sync commands should only run if the author has the administrator permission.""" diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 7b9f40cad..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -22,8 +22,9 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between roles in the DB and roles in the Guild cache."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = RoleSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      @staticmethod      def get_guild(*roles): @@ -44,7 +45,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role()]          guild = self.get_guild(fake_role()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = (set(), set(), set())          self.assertEqual(actual_diff, expected_diff) @@ -56,7 +57,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]          guild = self.get_guild(updated_role, fake_role()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = (set(), {_Role(**updated_role)}, set())          self.assertEqual(actual_diff, expected_diff) @@ -68,7 +69,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role()]          guild = self.get_guild(fake_role(), new_role) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = ({_Role(**new_role)}, set(), set())          self.assertEqual(actual_diff, expected_diff) @@ -80,7 +81,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          self.bot.api_client.get.return_value = [fake_role(), deleted_role]          guild = self.get_guild(fake_role()) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = (set(), set(), {_Role(**deleted_role)})          self.assertEqual(actual_diff, expected_diff) @@ -98,7 +99,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          ]          guild = self.get_guild(fake_role(), new, updated) -        actual_diff = await self.syncer._get_diff(guild) +        actual_diff = await RoleSyncer._get_diff(guild)          expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})          self.assertEqual(actual_diff, expected_diff) @@ -108,8 +109,9 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync roles."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = RoleSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      async def test_sync_created_roles(self):          """Only POST requests should be made with the correct payload.""" @@ -117,7 +119,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          role_tuples = {_Role(**role) for role in roles}          diff = _Diff(role_tuples, set(), set()) -        await self.syncer._sync(diff) +        await RoleSyncer._sync(diff)          calls = [mock.call("bot/roles", json=role) for role in roles]          self.bot.api_client.post.assert_has_calls(calls, any_order=True) @@ -132,7 +134,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          role_tuples = {_Role(**role) for role in roles}          diff = _Diff(set(), role_tuples, set()) -        await self.syncer._sync(diff) +        await RoleSyncer._sync(diff)          calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]          self.bot.api_client.put.assert_has_calls(calls, any_order=True) @@ -147,7 +149,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          role_tuples = {_Role(**role) for role in roles}          diff = _Diff(set(), set(), role_tuples) -        await self.syncer._sync(diff) +        await RoleSyncer._sync(diff)          calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]          self.bot.api_client.delete.assert_has_calls(calls, any_order=True) diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index c0a1da35c..61673e1bb 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,7 @@  import unittest  from unittest import mock -from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User +from bot.exts.backend.sync._syncers import UserSyncer, _Diff  from tests import helpers @@ -10,7 +10,7 @@ def fake_user(**kwargs):      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man")      kwargs.setdefault("discriminator", 1337) -    kwargs.setdefault("roles", (666,)) +    kwargs.setdefault("roles", [666])      kwargs.setdefault("in_guild", True)      return kwargs @@ -20,8 +20,9 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):      """Tests for determining differences between users in the DB and users in the Guild cache."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = UserSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      @staticmethod      def get_guild(*members): @@ -40,22 +41,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          return guild +    @staticmethod +    def get_mock_member(member: dict): +        member = member.copy() +        del member["in_guild"] +        mock_member = helpers.MockMember(**member) +        mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] +        return mock_member +      async def test_empty_diff_for_no_users(self):          """When no users are given, an empty diff should be returned.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [] +        }          guild = self.get_guild() -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_identical_users(self):          """No differences should be found if the users in the guild and DB are identical.""" -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user()) -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        guild.get_member.return_value = self.get_mock_member(fake_user()) +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -63,59 +84,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          """Only updated users should be added to the 'updated' set of the diff."""          updated_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(id=99, name="old"), fake_user()] +        }          guild = self.get_guild(updated_user, fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(updated_user), +            self.get_mock_member(fake_user()) +        ] -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**updated_user)}, None) +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([], [{"id": 99, "name": "new"}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_users(self): -        """Only new users should be added to the 'created' set of the diff.""" +        """Only new users should be added to the 'created' list of the diff."""          new_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user(), new_user) - -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, set(), None) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            self.get_mock_member(new_user) +        ] +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([new_user], [], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        leaving_user = fake_user(id=63, in_guild=False) - -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=63)] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            None +        ] -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**leaving_user)}, None) +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([], [{"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_updated_and_leaving_users(self):          """When users are added, updated, and removed, all of them are returned properly."""          new_user = fake_user(id=99, name="new") +          updated_user = fake_user(id=55, name="updated") -        leaving_user = fake_user(id=63, in_guild=False) -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=55), fake_user(id=63)] +        }          guild = self.get_guild(fake_user(), new_user, updated_user) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            self.get_mock_member(updated_user), +            None +        ] -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_db_users_not_in_guild(self): -        """When the DB knows a user the guild doesn't, no difference is found.""" -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        """When the DB knows a user, but the guild doesn't, no difference is found.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=63, in_guild=False)] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            None +        ] -        actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        actual_diff = await UserSyncer._get_diff(guild) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -124,20 +188,18 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):      """Tests for the API requests that sync users."""      def setUp(self): -        self.bot = helpers.MockBot() -        self.syncer = UserSyncer(self.bot) +        patcher = mock.patch("bot.instance", new=helpers.MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      async def test_sync_created_users(self):          """Only POST requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] -        user_tuples = {_User(**user) for user in users} -        diff = _Diff(user_tuples, set(), None) -        await self.syncer._sync(diff) +        diff = _Diff(users, [], None) +        await UserSyncer._sync(diff) -        calls = [mock.call("bot/users", json=user) for user in users] -        self.bot.api_client.post.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.post.call_count, len(users)) +        self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() @@ -146,13 +208,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          """Only PUT requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] -        user_tuples = {_User(**user) for user in users} -        diff = _Diff(set(), user_tuples, None) -        await self.syncer._sync(diff) +        diff = _Diff([], users, None) +        await UserSyncer._sync(diff) -        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] -        self.bot.api_client.put.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.put.call_count, len(users)) +        self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index d3f2995fb..daede54c5 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -1,4 +1,3 @@ -import asyncio  import textwrap  import unittest  import unittest.mock @@ -13,7 +12,7 @@ from tests import helpers  COG_PATH = "bot.exts.info.information.Information" -class InformationCogTests(unittest.TestCase): +class InformationCogTests(unittest.IsolatedAsyncioTestCase):      """Tests the Information cog."""      @classmethod @@ -29,16 +28,14 @@ class InformationCogTests(unittest.TestCase):          self.ctx = helpers.MockContext()          self.ctx.author.roles.append(self.moderator_role) -    def test_roles_command_command(self): +    async def test_roles_command_command(self):          """Test if the `role_info` command correctly returns the `moderator_role`."""          self.ctx.guild.roles.append(self.moderator_role)          self.cog.roles_info.can_run = unittest.mock.AsyncMock()          self.cog.roles_info.can_run.return_value = True -        coroutine = self.cog.roles_info.callback(self.cog, self.ctx) - -        self.assertIsNone(asyncio.run(coroutine)) +        self.assertIsNone(await self.cog.roles_info(self.cog, self.ctx))          self.ctx.send.assert_called_once()          _, kwargs = self.ctx.send.call_args @@ -48,7 +45,7 @@ class InformationCogTests(unittest.TestCase):          self.assertEqual(embed.colour, discord.Colour.blurple())          self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") -    def test_role_info_command(self): +    async def test_role_info_command(self):          """Tests the `role info` command."""          dummy_role = helpers.MockRole(              name="Dummy", @@ -73,9 +70,7 @@ class InformationCogTests(unittest.TestCase):          self.cog.role_info.can_run = unittest.mock.AsyncMock()          self.cog.role_info.can_run.return_value = True -        coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) - -        self.assertIsNone(asyncio.run(coroutine)) +        self.assertIsNone(await self.cog.role_info(self.cog, self.ctx, dummy_role, admin_role))          self.assertEqual(self.ctx.send.call_count, 2) @@ -97,80 +92,8 @@ class InformationCogTests(unittest.TestCase):          self.assertEqual(admin_embed.title, "Admins info")          self.assertEqual(admin_embed.colour, discord.Colour.red()) -    @unittest.mock.patch('bot.exts.info.information.time_since') -    def test_server_info_command(self, time_since_patch): -        time_since_patch.return_value = '2 days ago' - -        self.ctx.guild = helpers.MockGuild( -            features=('lemons', 'apples'), -            region="The Moon", -            roles=[self.moderator_role], -            channels=[ -                discord.TextChannel( -                    state={}, -                    guild=self.ctx.guild, -                    data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} -                ), -                discord.CategoryChannel( -                    state={}, -                    guild=self.ctx.guild, -                    data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} -                ), -                discord.VoiceChannel( -                    state={}, -                    guild=self.ctx.guild, -                    data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} -                ) -            ], -            members=[ -                *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), -                *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), -                *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), -                *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), -            ], -            member_count=1_234, -            icon_url='a-lemon.jpg', -        ) - -        coroutine = self.cog.server_info.callback(self.cog, self.ctx) -        self.assertIsNone(asyncio.run(coroutine)) - -        time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') -        _, kwargs = self.ctx.send.call_args -        embed = kwargs.pop('embed') -        self.assertEqual(embed.colour, discord.Colour.blurple()) -        self.assertEqual( -            embed.description, -            textwrap.dedent( -                f""" -                **Server information** -                Created: {time_since_patch.return_value} -                Voice region: {self.ctx.guild.region} -                Features: {', '.join(self.ctx.guild.features)} - -                **Channel counts** -                Category channels: 1 -                Text channels: 1 -                Voice channels: 1 -                Staff channels: 0 - -                **Member counts** -                Members: {self.ctx.guild.member_count:,} -                Staff members: 0 -                Roles: {len(self.ctx.guild.roles)} - -                **Member statuses** -                {constants.Emojis.status_online} 2 -                {constants.Emojis.status_idle} 1 -                {constants.Emojis.status_dnd} 4 -                {constants.Emojis.status_offline} 3 -                """ -            ) -        ) -        self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') - -class UserInfractionHelperMethodTests(unittest.TestCase): +class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):      """Tests for the helper methods of the `!user` command."""      def setUp(self): @@ -180,7 +103,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          self.cog = information.Information(self.bot)          self.member = helpers.MockMember(id=1234) -    def test_user_command_helper_method_get_requests(self): +    async def test_user_command_helper_method_get_requests(self):          """The helper methods should form the correct get requests."""          test_values = (              { @@ -202,11 +125,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase):              endpoint, params = test_value["expected_args"]              with self.subTest(method=helper_method, endpoint=endpoint, params=params): -                asyncio.run(helper_method(self.member)) +                await helper_method(self.member)                  self.bot.api_client.get.assert_called_once_with(endpoint, params=params)                  self.bot.api_client.get.reset_mock() -    def _method_subtests(self, method, test_values, default_header): +    async def _method_subtests(self, method, test_values, default_header):          """Helper method that runs the subtests for the different helper methods."""          for test_value in test_values:              api_response = test_value["api response"] @@ -216,11 +139,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase):                  self.bot.api_client.get.return_value = api_response                  expected_output = "\n".join(expected_lines) -                actual_output = asyncio.run(method(self.member)) +                actual_output = await method(self.member)                  self.assertEqual((default_header, expected_output), actual_output) -    def test_basic_user_infraction_counts_returns_correct_strings(self): +    async def test_basic_user_infraction_counts_returns_correct_strings(self):          """The method should correctly list both the total and active number of non-hidden infractions."""          test_values = (              # No infractions means zero counts @@ -251,9 +174,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          header = "Infractions" -        self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) +        await self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) -    def test_expanded_user_infraction_counts_returns_correct_strings(self): +    async def test_expanded_user_infraction_counts_returns_correct_strings(self):          """The method should correctly list the total and active number of all infractions split by infraction type."""          test_values = (              { @@ -306,9 +229,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          header = "Infractions" -        self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) +        await self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) -    def test_user_nomination_counts_returns_correct_strings(self): +    async def test_user_nomination_counts_returns_correct_strings(self):          """The method should list the number of active and historical nominations for the user."""          test_values = (              { @@ -336,12 +259,12 @@ class UserInfractionHelperMethodTests(unittest.TestCase):          header = "Nominations" -        self._method_subtests(self.cog.user_nomination_counts, test_values, header) +        await self._method_subtests(self.cog.user_nomination_counts, test_values, header)  @unittest.mock.patch("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago"))  @unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50]) -class UserEmbedTests(unittest.TestCase): +class UserEmbedTests(unittest.IsolatedAsyncioTestCase):      """Tests for the creation of the `!user` embed."""      def setUp(self): @@ -354,14 +277,14 @@ class UserEmbedTests(unittest.TestCase):          f"{COG_PATH}.basic_user_infraction_counts",          new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))      ) -    def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): +    async def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          user = helpers.MockMember()          user.nick = None          user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          self.assertEqual(embed.title, "Mr. Hemlock") @@ -369,14 +292,14 @@ class UserEmbedTests(unittest.TestCase):          f"{COG_PATH}.basic_user_infraction_counts",          new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))      ) -    def test_create_user_embed_uses_nick_in_title_if_available(self): +    async def test_create_user_embed_uses_nick_in_title_if_available(self):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          user = helpers.MockMember()          user.nick = "Cat lover"          user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock") -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") @@ -384,7 +307,7 @@ class UserEmbedTests(unittest.TestCase):          f"{COG_PATH}.basic_user_infraction_counts",          new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))      ) -    def test_create_user_embed_ignores_everyone_role(self): +    async def test_create_user_embed_ignores_everyone_role(self):          """Created `!user` embeds should not contain mention of the @everyone-role."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          admins_role = helpers.MockRole(name='Admins') @@ -393,14 +316,18 @@ class UserEmbedTests(unittest.TestCase):          # A `MockMember` has the @Everyone role by default; we add the Admins to that.          user = helpers.MockMember(roles=[admins_role], top_role=admins_role) -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          self.assertIn("&Admins", embed.fields[1].value)          self.assertNotIn("&Everyone", embed.fields[1].value)      @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)      @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) -    def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): +    async def test_create_user_embed_expanded_information_in_moderation_channels( +            self, +            nomination_counts, +            infraction_counts +    ):          """The embed should contain expanded infractions and nomination info in mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -411,7 +338,7 @@ class UserEmbedTests(unittest.TestCase):          nomination_counts.return_value = ("Nominations", "nomination info")          user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          infraction_counts.assert_called_once_with(user)          nomination_counts.assert_called_once_with(user) @@ -434,7 +361,7 @@ class UserEmbedTests(unittest.TestCase):          )      @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) -    def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): +    async def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):          """The embed should contain only basic infraction data outside of mod channels."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -444,7 +371,7 @@ class UserEmbedTests(unittest.TestCase):          infraction_counts.return_value = ("Infractions", "basic infractions info")          user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          infraction_counts.assert_called_once_with(user) @@ -467,14 +394,14 @@ class UserEmbedTests(unittest.TestCase):          self.assertEqual(              "basic infractions info", -            embed.fields[3].value +            embed.fields[2].value          )      @unittest.mock.patch(          f"{COG_PATH}.basic_user_infraction_counts",          new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))      ) -    def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): +    async def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):          """The embed should be created with the colour of the top role, if a top role is available."""          ctx = helpers.MockContext() @@ -482,7 +409,7 @@ class UserEmbedTests(unittest.TestCase):          moderators_role.colour = 100          user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) @@ -490,12 +417,12 @@ class UserEmbedTests(unittest.TestCase):          f"{COG_PATH}.basic_user_infraction_counts",          new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))      ) -    def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): +    async def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):          """The embed should be created with a blurple colour if the user has no assigned roles."""          ctx = helpers.MockContext()          user = helpers.MockMember(id=217) -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          self.assertEqual(embed.colour, discord.Colour.blurple()) @@ -503,20 +430,20 @@ class UserEmbedTests(unittest.TestCase):          f"{COG_PATH}.basic_user_infraction_counts",          new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))      ) -    def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): +    async def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):          """The embed thumbnail should be set to the user's avatar in `png` format."""          ctx = helpers.MockContext()          user = helpers.MockMember(id=217)          user.avatar_url_as.return_value = "avatar url" -        embed = asyncio.run(self.cog.create_user_embed(ctx, user)) +        embed = await self.cog.create_user_embed(ctx, user)          user.avatar_url_as.assert_called_once_with(static_format="png")          self.assertEqual(embed.thumbnail.url, "avatar url")  @unittest.mock.patch("bot.exts.info.information.constants") -class UserCommandTests(unittest.TestCase): +class UserCommandTests(unittest.IsolatedAsyncioTestCase):      """Tests for the `!user` command."""      def setUp(self): @@ -536,16 +463,16 @@ class UserCommandTests(unittest.TestCase):          # used as a default value for a parameter, which gets defined upon import.          self.bot_command_channel = helpers.MockTextChannel(id=constants.Channels.bot_commands) -    def test_regular_member_cannot_target_another_member(self, constants): +    async def test_regular_member_cannot_target_another_member(self, constants):          """A regular user should not be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.author) -        asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) +        await self.cog.user_info(self.cog, ctx, self.target)          ctx.send.assert_called_once_with("You may not use this command on users other than yourself.") -    def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants): +    async def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants):          """A regular user should not be able to use this command outside of bot-commands."""          constants.MODERATION_ROLES = [self.moderator_role.id]          constants.STAFF_ROLES = [self.moderator_role.id] @@ -553,49 +480,49 @@ class UserCommandTests(unittest.TestCase):          msg = "Sorry, but you may only use this command within <#50>."          with self.assertRaises(InWhitelistCheckFailure, msg=msg): -            asyncio.run(self.cog.user_info.callback(self.cog, ctx)) +            await self.cog.user_info(self.cog, ctx)      @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") -    def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): +    async def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):          """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""          constants.STAFF_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel) -        asyncio.run(self.cog.user_info.callback(self.cog, ctx)) +        await self.cog.user_info(self.cog, ctx)          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once()      @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") -    def test_regular_user_can_explicitly_target_themselves(self, create_embed, _): +    async def test_regular_user_can_explicitly_target_themselves(self, create_embed, _):          """A user should target itself with `!user` when a `user` argument was not provided."""          constants.STAFF_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel) -        asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author)) +        await self.cog.user_info(self.cog, ctx, self.author)          create_embed.assert_called_once_with(ctx, self.author)          ctx.send.assert_called_once()      @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") -    def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): +    async def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction."""          constants.STAFF_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) -        asyncio.run(self.cog.user_info.callback(self.cog, ctx)) +        await self.cog.user_info(self.cog, ctx)          create_embed.assert_called_once_with(ctx, self.moderator)          ctx.send.assert_called_once()      @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed") -    def test_moderators_can_target_another_member(self, create_embed, constants): +    async def test_moderators_can_target_another_member(self, create_embed, constants):          """A moderator should be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id]          constants.STAFF_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50)) -        asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target)) +        await self.cog.user_info(self.cog, ctx, self.target)          create_embed.assert_called_once_with(ctx, self.target)          ctx.send.assert_called_once() diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index be1b649e1..bf557a484 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,7 +1,8 @@  import textwrap  import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch +from bot.constants import Event  from bot.exts.moderation.infraction.infractions import Infractions  from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.apply_infraction.assert_awaited_once_with(              self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value          ) + + +@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) +class VoiceBanTests(unittest.IsolatedAsyncioTestCase): +    """Tests for voice ban related functions and commands.""" + +    def setUp(self): +        self.bot = MockBot() +        self.mod = MockMember(top_role=10) +        self.user = MockMember(top_role=1, roles=[MockRole(id=123456)]) +        self.guild = MockGuild() +        self.ctx = MockContext(bot=self.bot, author=self.mod) +        self.cog = Infractions(self.bot) + +    async def test_permanent_voice_ban(self): +        """Should call voice ban applying function without expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar") + +    async def test_temporary_voice_ban(self): +        """Should call voice ban applying function with expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + +    async def test_voice_unban(self): +        """Should call infraction pardoning function.""" +        self.cog.pardon_infraction = AsyncMock() +        self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) +        self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): +        """Should return early when user already have Voice Ban infraction.""" +        get_active_infraction.return_value = {"foo": "bar"} +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") +        post_infraction_mock.assert_not_awaited() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): +        """Should return early when posting infraction fails.""" +        self.cog.mod_log.ignore = MagicMock() +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        post_infraction_mock.assert_awaited_once() +        self.cog.mod_log.ignore.assert_not_called() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): +        """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" +        get_active_infraction.return_value = None +        # We don't want that this continue yet +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) +        post_infraction_mock.assert_awaited_once_with( +            self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 +        ) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar") +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): +        """Should truncate reason for voice ban.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) +        self.user.remove_roles.assert_called_once_with( +            self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...") +        ) +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    async def test_voice_unban_user_not_found(self): +        """Should include info to return dict when user was not found from guild.""" +        self.guild.get_member.return_value = None +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, {"Info": "User was not found in the guild."}) + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = True +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "Sent" +        }) +        notify_pardon_mock.assert_awaited_once() + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = False +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "**Failed**" +        }) +        notify_pardon_mock.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index e2d44c637..104293d8e 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,23 +1,49 @@ +import asyncio  import unittest +from datetime import datetime, timezone  from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock +from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Emojis, Guild, Roles -from bot.exts.moderation.silence import Silence, SilenceNotifier -from tests.helpers import MockBot, MockContext, MockTextChannel +from bot.constants import Channels, Guild, Roles +from bot.exts.moderation import silence +from tests.helpers import MockBot, MockContext, MockTextChannel, autospec + +redis_session = None +redis_loop = asyncio.get_event_loop() + + +def setUpModule():  # noqa: N802 +    """Create and connect to the fakeredis session.""" +    global redis_session +    redis_session = RedisSession(use_fakeredis=True) +    redis_loop.run_until_complete(redis_session.connect()) + + +def tearDownModule():  # noqa: N802 +    """Close the fakeredis session.""" +    if redis_session: +        redis_loop.run_until_complete(redis_session.close()) + + +# Have to subclass it because builtins can't be patched. +class PatchedDatetime(datetime): +    """A datetime object with a mocked now() function.""" + +    now = mock.create_autospec(datetime, "now")  class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None:          self.alert_channel = MockTextChannel() -        self.notifier = SilenceNotifier(self.alert_channel) +        self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock()          self.notifier.start = self.notifier_start_mock = Mock()      def test_add_channel_adds_channel(self): -        """Channel in FirstHash with current loop is added to internal set.""" +        """Channel is added to `_silenced_channels` with the current loop."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.add_channel(channel) @@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):          self.notifier_start_mock.assert_not_called()      def test_remove_channel_removes_channel(self): -        """Channel in FirstHash is removed from `_silenced_channels`.""" +        """Channel is removed from `_silenced_channels`."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.remove_channel(channel) @@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):              with self.subTest(current_loop=current_loop):                  with mock.patch.object(self.notifier, "_current_loop", new=current_loop):                      await self.notifier._notifier() -                self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") +                self.alert_channel.send.assert_called_once_with( +                    f"<@&{Roles.moderators}> currently silenced channels: " +                )              self.alert_channel.send.reset_mock()      async def test_notifier_skips_alert(self): @@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):                      self.alert_channel.send.assert_not_called() -class SilenceTests(unittest.IsolatedAsyncioTestCase): +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the general functionality of the Silence cog.""" + +    @autospec(silence, "Scheduler", pass_mocks=False)      def setUp(self) -> None:          self.bot = MockBot() -        self.cog = Silence(self.bot) -        self.ctx = MockContext() -        self.cog._verified_role = None -        # Set event so command callbacks can continue. -        self.cog._get_instance_vars_event.set() +        self.cog = silence.Silence(self.bot) -    async def test_instance_vars_got_guild(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog._get_instance_vars() -        self.bot.wait_until_guild_available.assert_called_once() +        await self.cog._async_init() +        self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id) -    async def test_instance_vars_got_role(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_role(self):          """Got `Roles.verified` role from guild.""" -        await self.cog._get_instance_vars()          guild = self.bot.get_guild() -        guild.get_role.assert_called_once_with(Roles.verified) +        guild.get_role.side_effect = lambda id_: Mock(id=id_) -    async def test_instance_vars_got_channels(self): +        await self.cog._async_init() +        self.assertEqual(self.cog._verified_role.id, Roles.verified) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_channels(self):          """Got channels from bot.""" -        await self.cog._get_instance_vars() -        self.bot.get_channel.called_once_with(Channels.mod_alerts) -        self.bot.get_channel.called_once_with(Channels.mod_log) +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) -    @mock.patch("bot.exts.moderation.silence.SilenceNotifier") -    async def test_instance_vars_got_notifier(self, notifier): +    @autospec(silence, "SilenceNotifier") +    async def test_async_init_got_notifier(self, notifier):          """Notifier was started with channel.""" -        mod_log = MockTextChannel() -        self.bot.get_channel.side_effect = (None, mod_log) -        await self.cog._get_instance_vars() -        notifier.assert_called_once_with(mod_log) -        self.bot.get_channel.side_effect = None - -    async def test_silence_sent_correct_discord_message(self): -        """Check if proper message was sent when called with duration in channel with previous state.""" +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) +        self.assertEqual(self.cog.notifier, notifier.return_value) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_rescheduled(self): +        """`_reschedule_` coroutine was awaited.""" +        self.cog._reschedule = mock.create_autospec(self.cog._reschedule) +        await self.cog._async_init() +        self.cog._reschedule.assert_awaited_once_with() + +    def test_cog_unload_cancelled_tasks(self): +        """The init task was cancelled.""" +        self.cog._init_task = asyncio.Future() +        self.cog.cog_unload() + +        # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. +        self.assertTrue(self.cog._init_task.cancelled()) + +    @autospec("discord.ext.commands", "has_any_role") +    @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) +    async def test_cog_check(self, role_check): +        """Role check was called with `MODERATION_ROLES`""" +        ctx = MockContext() +        role_check.return_value.predicate = mock.AsyncMock() + +        await self.cog.cog_check(ctx) +        role_check.assert_called_once_with(*(1, 2, 3)) +        role_check.return_value.predicate.assert_awaited_once_with(ctx) + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class RescheduleTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the rescheduling of cached unsilences.""" + +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self): +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) + +        with mock.patch.object(self.cog, "_reschedule", autospec=True): +            asyncio.run(self.cog._async_init())  # Populate instance attributes. + +    async def test_skipped_missing_channel(self): +        """Did nothing because the channel couldn't be retrieved.""" +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] +        self.bot.get_channel.return_value = None + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_added_permanent_to_notifier(self): +        """Permanently silenced channels were added to the notifier.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_any_call(channels[0]) +        self.cog.notifier.add_channel.assert_any_call(channels[1]) + +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_unsilenced_expired(self): +        """Unsilenced expired silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] + +        await self.cog._reschedule() + +        self.cog._unsilence_wrapper.assert_any_call(channels[0]) +        self.cog._unsilence_wrapper.assert_any_call(channels[1]) + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    @mock.patch.object(silence, "datetime", new=PatchedDatetime) +    async def test_rescheduled_active(self): +        """Rescheduled active silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] +        silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc) + +        self.cog._unsilence_wrapper = mock.MagicMock() +        unsilence_return = self.cog._unsilence_wrapper.return_value + +        await self.cog._reschedule() + +        # Yuck. +        calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)] +        self.cog.scheduler.schedule_later.assert_has_calls(calls) + +        unsilence_calls = [mock.call(channel) for channel in channels] +        self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls) + +        self.cog.notifier.add_channel.assert_not_called() + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the silence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        # Avoid unawaited coroutine warnings. +        self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command."""          test_cases = ( -            (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), -            (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), -            (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), +            (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), +            (None, silence.MSG_SILENCE_PERMANENT, True,), +            (5, silence.MSG_SILENCE_FAIL, False,),          ) -        for duration, result_message, _silence_patch_return in test_cases: -            with self.subTest( -                silence_duration=duration, -                result_message=result_message, -                starting_unsilenced_state=_silence_patch_return -            ): -                with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): -                    await self.cog.silence.callback(self.cog, self.ctx, duration) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_unsilence_sent_correct_discord_message(self): -        """Check if proper message was sent when unsilencing channel.""" -        test_cases = ( -            (True, f"{Emojis.check_mark} unsilenced current channel."), -            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        for duration, message, was_silenced in test_cases: +            ctx = MockContext() +            with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): +                with self.subTest(was_silenced=was_silenced, message=message, duration=duration): +                    await self.cog.silence.callback(self.cog, ctx, duration) +                    ctx.send.assert_called_once_with(message) + +    async def test_skipped_already_silenced(self): +        """Permissions were not set and `False` was returned for an already silenced channel.""" +        subtests = ( +            (False, PermissionOverwrite(send_messages=False, add_reactions=False)), +            (True, PermissionOverwrite(send_messages=True, add_reactions=True)), +            (True, PermissionOverwrite(send_messages=False, add_reactions=False)),          ) -        for _unsilence_patch_return, result_message in test_cases: -            with self.subTest( -                starting_silenced_state=_unsilence_patch_return, -                result_message=result_message -            ): -                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): -                    await self.cog.unsilence.callback(self.cog, self.ctx) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_silence_private_for_false(self): -        """Permissions are not set and `False` is returned in an already silenced channel.""" -        perm_overwrite = Mock(send_messages=False) -        channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - -        self.assertFalse(await self.cog._silence(channel, True, None)) -        channel.set_permissions.assert_not_called() -    async def test_silence_private_silenced_channel(self): -        """Channel had `send_message` permissions revoked.""" -        channel = MockTextChannel() -        self.assertTrue(await self.cog._silence(channel, False, None)) -        channel.set_permissions.assert_called_once() -        self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) +        for contains, overwrite in subtests: +            with self.subTest(contains=contains, overwrite=overwrite): +                self.cog.scheduler.__contains__.return_value = contains +                channel = MockTextChannel() +                channel.overwrites_for.return_value = overwrite + +                self.assertFalse(await self.cog._set_silence_overwrites(channel)) +                channel.set_permissions.assert_not_called() + +    async def test_silenced_channel(self): +        """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) +        self.assertFalse(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite +        ) -    async def test_silence_private_preserves_permissions(self): -        """Previous permissions were preserved when channel was silenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite() -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._silence(channel, False, None) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    async def test_silence_private_notifier(self): -        """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" -        channel = MockTextChannel() -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=True): -                await self.cog._silence(channel, True, None) -                self.cog.notifier.add_channel.assert_called_once() - -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=False): -                await self.cog._silence(channel, False, None) -                self.cog.notifier.add_channel.assert_not_called() - -    async def test_silence_private_added_muted_channel(self): -        """Channel was added to `muted_channels` on silence.""" +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed.""" +        prev_overwrite_dict = dict(self.overwrite) +        await self.cog._set_silence_overwrites(self.channel) +        new_overwrite_dict = dict(self.overwrite) + +        # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. +        del prev_overwrite_dict['send_messages'] +        del prev_overwrite_dict['add_reactions'] +        del new_overwrite_dict['send_messages'] +        del new_overwrite_dict['add_reactions'] + +        self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_temp_not_added_to_notifier(self): +        """Channel was not added to notifier if a duration was set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_indefinite_added_to_notifier(self): +        """Channel was added to notifier if a duration was not set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), None) +            self.cog.notifier.add_channel.assert_called_once() + +    async def test_silenced_not_added_to_notifier(self): +        """Channel was not added to the notifier if it was already silenced.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_cached_previous_overwrites(self): +        """Channel's previous overwrites were cached.""" +        overwrite_json = '{"send_messages": true, "add_reactions": false}' +        await self.cog._set_silence_overwrites(self.channel) +        self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + +    @autospec(silence, "datetime") +    async def test_cached_unsilence_time(self, datetime_mock): +        """The UTC POSIX timestamp for the unsilence was cached.""" +        now_timestamp = 100 +        duration = 15 +        timestamp = now_timestamp + duration * 60 +        datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) + +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, duration) + +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) +        datetime_mock.now.assert_called_once_with(tz=timezone.utc)  # Ensure it's using an aware dt. + +    async def test_cached_indefinite_time(self): +        """A value of -1 was cached for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) + +    async def test_scheduled_task(self): +        """An unsilence task was scheduled.""" +        ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + +        await self.cog.silence.callback(self.cog, ctx, 5) + +        args = (300, ctx.channel.id, ctx.invoke.return_value) +        self.cog.scheduler.schedule_later.assert_called_once_with(*args) +        ctx.invoke.assert_called_once_with(self.cog.unsilence) + +    async def test_permanent_not_scheduled(self): +        """A task was not scheduled for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.scheduler.schedule_later.assert_not_called() + + +@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) +class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the unsilence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.cog.scheduler.__contains__.return_value = True +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command.""" +        unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) +        test_cases = ( +            (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), +        ) +        for was_unsilenced, message, overwrite in test_cases: +            ctx = MockContext() +            with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): +                with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): +                    ctx.channel.overwrites_for.return_value = overwrite +                    await self.cog.unsilence.callback(self.cog, ctx) +                    ctx.channel.send.assert_called_once_with(message) + +    async def test_skipped_already_unsilenced(self): +        """Permissions were not set and `False` was returned for an already unsilenced channel.""" +        self.cog.scheduler.__contains__.return_value = False +        self.cog.previous_overwrites.get.return_value = None          channel = MockTextChannel() -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._silence(channel, False, None) -        muted_channels.add.assert_called_once_with(channel) -    async def test_unsilence_private_for_false(self): -        """Permissions are not set and `False` is returned in an unsilenced channel.""" -        channel = Mock()          self.assertFalse(await self.cog._unsilence(channel))          channel.set_permissions.assert_not_called() -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_unsilenced_channel(self, _): -        """Channel had `send_message` permissions restored""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        self.assertTrue(await self.cog._unsilence(channel)) -        channel.set_permissions.assert_called_once() -        self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_notifier(self, notifier): -        """Channel was removed from `notifier` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        await self.cog._unsilence(channel) -        notifier.remove_channel.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_muted_channel(self, _): -        """Channel was removed from `muted_channels` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._unsilence(channel) -        muted_channels.discard.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_preserves_permissions(self, _): -        """Previous permissions were preserved when channel was unsilenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite(send_messages=False) -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._unsilence(channel) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    @mock.patch.object(Silence, "_mod_alerts_channel", create=True) -    def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): -        """Task for sending an alert was created with present `muted_channels`.""" -        with mock.patch.object(self.cog, "muted_channels"): -            self.cog.cog_unload() -            alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") -            asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    def test_cog_unload_skips_task_start(self, asyncio_mock): -        """No task created with no channels.""" -        self.cog.cog_unload() -        asyncio_mock.create_task.assert_not_called() +    async def test_restored_overwrites(self): +        """Channel's `send_message` and `add_reactions` overwrites were restored.""" +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) -    @mock.patch("discord.ext.commands.has_any_role") -    @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) -    async def test_cog_check(self, role_check): -        """Role check is called with `MODERATION_ROLES`""" -        role_check.return_value.predicate = mock.AsyncMock() -        await self.cog.cog_check(self.ctx) -        role_check.assert_called_once_with(*(1, 2, 3)) -        role_check.return_value.predicate.assert_awaited_once_with(self.ctx) +        # Recall that these values are determined by the fixture. +        self.assertTrue(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) + +    async def test_cache_miss_used_default_overwrites(self): +        """Both overwrites were set to None due previous values not being found in the cache.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) + +        self.assertIsNone(self.overwrite.send_messages) +        self.assertIsNone(self.overwrite.add_reactions) + +    async def test_cache_miss_sent_mod_alert(self): +        """A message was sent to the mod alerts channel.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.cog._mod_alerts_channel.send.assert_awaited_once() + +    async def test_removed_notifier(self): +        """Channel was removed from `notifier`.""" +        await self.cog._unsilence(self.channel) +        self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + +    async def test_deleted_cached_overwrite(self): +        """Channel was deleted from the overwrites cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + +    async def test_deleted_cached_time(self): +        """Channel was deleted from the timestamp cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + +    async def test_cancelled_task(self): +        """The scheduled unsilence task should be cancelled.""" +        await self.cog._unsilence(self.channel) +        self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed, including cache misses.""" +        for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): +            with self.subTest(overwrite_json=overwrite_json): +                self.cog.previous_overwrites.get.return_value = overwrite_json + +                prev_overwrite_dict = dict(self.overwrite) +                await self.cog._unsilence(self.channel) +                new_overwrite_dict = dict(self.overwrite) + +                # Remove these keys because they were modified by the unsilence. +                del prev_overwrite_dict['send_messages'] +                del prev_overwrite_dict['add_reactions'] +                del new_overwrite_dict['send_messages'] +                del new_overwrite_dict['add_reactions'] + +                self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 40b2202aa..321a92445 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -42,9 +42,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      async def test_upload_output(self, mock_paste_util):          """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""          await self.cog.upload_output("Test output.") -        mock_paste_util.assert_called_once_with( -            self.bot.http_session, "Test output.", extension="txt" -        ) +        mock_paste_util.assert_called_once_with("Test output.", extension="txt")      def test_prepare_input(self):          cases = ( @@ -52,6 +50,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),              ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),              ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), +            ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'), +            ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', +             'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'), +            ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', +             'print("How\'s it going?")', 'code block preceded by inline code'), +            ('`print("Hello world!")`\ntext\n`print("Hello world!")`', +             'print("Hello world!")', 'one inline code block of two')          )          for case, expected, testname in cases:              with self.subTest(msg=f'Extract code from {testname}.'): @@ -154,7 +159,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.send_eval = AsyncMock(return_value=response)          self.cog.continue_eval = AsyncMock(return_value=None) -        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')          self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')          self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')          self.cog.continue_eval.assert_called_once_with(ctx, response) @@ -168,7 +173,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.continue_eval = AsyncMock()          self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) -        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')          self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))          self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')          self.cog.continue_eval.assert_called_with(ctx, response) @@ -180,7 +185,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          ctx.author.mention = '@LemonLemonishBeard#0042'          ctx.send = AsyncMock()          self.cog.jobs = (42,) -        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') +        await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')          ctx.send.assert_called_once_with(              "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"          ) @@ -188,8 +193,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      async def test_eval_command_call_help(self):          """Test if the eval command call the help command if no code is provided."""          ctx = MockContext(command="sentinel") -        await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') -        ctx.send_help.assert_called_once_with("sentinel") +        await self.cog.eval_command(self.cog, ctx=ctx, code='') +        ctx.send_help.assert_called_once_with(ctx.command)      async def test_send_eval(self):          """Test the send_eval function.""" @@ -290,7 +295,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              )          )          ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) -        ctx.message.clear_reactions.assert_called_once() +        ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)          response.delete.assert_called_once()      async def test_continue_eval_does_not_continue(self): @@ -299,7 +304,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          actual = await self.cog.continue_eval(ctx, MockMessage())          self.assertEqual(actual, None) -        ctx.message.clear_reactions.assert_called_once() +        ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)      async def test_get_code(self):          """Should return 1st arg (or None) if eval cmd in message, otherwise return full content.""" diff --git a/tests/bot/patches/__init__.py b/tests/bot/patches/__init__.py deleted file mode 100644 index e69de29bb..000000000 --- a/tests/bot/patches/__init__.py +++ /dev/null diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py index 9a72723e2..66c2d9f92 100644 --- a/tests/bot/rules/test_discord_emojis.py +++ b/tests/bot/rules/test_discord_emojis.py @@ -5,11 +5,12 @@ from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage  discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> +unicode_emoji = "🧪" -def make_msg(author: str, n_emojis: int) -> MockMessage: +def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage:      """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" -    return MockMessage(author=author, content=discord_emoji * n_emojis) +    return MockMessage(author=author, content=emoji * n_emojis)  class DiscordEmojisRuleTests(RuleTest): @@ -20,16 +21,22 @@ class DiscordEmojisRuleTests(RuleTest):          self.config = {"max": 2, "interval": 10}      async def test_allows_messages_within_limit(self): -        """Cases with a total amount of discord emojis within limit.""" +        """Cases with a total amount of discord and unicode emojis within limit."""          cases = (              [make_msg("bob", 2)],              [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], +            [make_msg("bob", 2, unicode_emoji)], +            [ +                make_msg("alice", 1, unicode_emoji), +                make_msg("bob", 2, unicode_emoji), +                make_msg("alice", 1, unicode_emoji) +            ],          )          await self.run_allowed(cases)      async def test_disallows_messages_beyond_limit(self): -        """Cases with more than the allowed amount of discord emojis.""" +        """Cases with more than the allowed amount of discord and unicode emojis."""          cases = (              DisallowedCase(                  [make_msg("bob", 3)], @@ -41,6 +48,20 @@ class DiscordEmojisRuleTests(RuleTest):                  ("alice",),                  4,              ), +            DisallowedCase( +                [make_msg("bob", 3, unicode_emoji)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [ +                    make_msg("alice", 2, unicode_emoji), +                    make_msg("bob", 2, unicode_emoji), +                    make_msg("alice", 2, unicode_emoji) +                ], +                ("alice",), +                4 +            )          )          await self.run_disallowed(cases) diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 5e0855704..1b48f6560 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -5,11 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch  from aiohttp import ClientConnectorError  from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from tests.helpers import MockBot  class PasteTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None: -        self.http_session = MagicMock() +        patcher = patch("bot.instance", new=MockBot()) +        self.bot = patcher.start() +        self.addCleanup(patcher.stop)      @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}")      async def test_url_and_sent_contents(self): @@ -17,10 +20,10 @@ class PasteTests(unittest.IsolatedAsyncioTestCase):          response = MagicMock(              json=AsyncMock(return_value={"key": ""})          ) -        self.http_session.post().__aenter__.return_value = response -        self.http_session.post.reset_mock() -        await send_to_paste_service(self.http_session, "Content") -        self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") +        self.bot.http_session.post.return_value.__aenter__.return_value = response +        self.bot.http_session.post.reset_mock() +        await send_to_paste_service("Content") +        self.bot.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content")      @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}")      async def test_paste_returns_correct_url_on_success(self): @@ -34,41 +37,41 @@ class PasteTests(unittest.IsolatedAsyncioTestCase):          response = MagicMock(              json=AsyncMock(return_value={"key": key})          ) -        self.http_session.post().__aenter__.return_value = response +        self.bot.http_session.post.return_value.__aenter__.return_value = response          for expected_output, extension in test_cases:              with self.subTest(msg=f"Send contents with extension {repr(extension)}"):                  self.assertEqual( -                    await send_to_paste_service(self.http_session, "", extension=extension), +                    await send_to_paste_service("", extension=extension),                      expected_output                  )      async def test_request_repeated_on_json_errors(self):          """Json with error message and invalid json are handled as errors and requests repeated."""          test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) -        self.http_session.post().__aenter__.return_value = response = MagicMock() -        self.http_session.post.reset_mock() +        self.bot.http_session.post.return_value.__aenter__.return_value = response = MagicMock() +        self.bot.http_session.post.reset_mock()          for error_json in test_cases:              with self.subTest(error_json=error_json):                  response.json = AsyncMock(return_value=error_json) -                result = await send_to_paste_service(self.http_session, "") -                self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +                result = await send_to_paste_service("") +                self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)                  self.assertIsNone(result) -            self.http_session.post.reset_mock() +            self.bot.http_session.post.reset_mock()      async def test_request_repeated_on_connection_errors(self):          """Requests are repeated in the case of connection errors.""" -        self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) -        result = await send_to_paste_service(self.http_session, "") -        self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +        self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) +        result = await send_to_paste_service("") +        self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)          self.assertIsNone(result)      async def test_general_error_handled_and_request_repeated(self):          """All `Exception`s are handled, logged and request repeated.""" -        self.http_session.post = MagicMock(side_effect=Exception) -        result = await send_to_paste_service(self.http_session, "") -        self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) +        self.bot.http_session.post = MagicMock(side_effect=Exception) +        result = await send_to_paste_service("") +        self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)          self.assertLogs("bot.utils", logging.ERROR)          self.assertIsNone(result) diff --git a/tests/helpers.py b/tests/helpers.py index e47fdf28f..870f66197 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools  import logging  import unittest.mock  from asyncio import AbstractEventLoop -from typing import Callable, Iterable, Optional +from typing import Iterable, Optional  import discord  from aiohttp import ClientSession @@ -14,6 +14,7 @@ from discord.ext.commands import Context  from bot.api import APIClient  from bot.async_stats import AsyncStatsClient  from bot.bot import Bot +from tests._autospec import autospec  # noqa: F401 other modules import it via this module  for logger in logging.Logger.manager.loggerDict.values(): @@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> Callable: -    """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" -    # Caller's kwargs should take priority and overwrite the defaults. -    kwargs = {'spec_set': True, 'autospec': True, **kwargs} - -    # Import the target if it's a string. -    # This is to support both object and string targets like patch.multiple. -    if type(target) is str: -        target = unittest.mock._importer(target) - -    def decorator(func): -        for attribute in attributes: -            patcher = unittest.mock.patch.object(target, attribute, **kwargs) -            func = patcher(func) -        return func -    return decorator - -  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. | 
