aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2020-11-10 01:42:51 +0200
committerGravatar GitHub <[email protected]>2020-11-10 01:42:51 +0200
commit690f906c613ce90da403fc7a8942dc4d24cea000 (patch)
treed3dac1ccaef94a1d5499322f1b2862ae3c51cbbf /tests
parentAutomatically add periods as visual buffers (diff)
parentMerge pull request #1273 from python-discord/sebastiaan/bugs/codeblock-langua... (diff)
Merge branch 'master' into pure/feature/infraction-append
Diffstat (limited to 'tests')
-rw-r--r--tests/_autospec.py64
-rw-r--r--tests/bot/exts/backend/sync/test_base.py359
-rw-r--r--tests/bot/exts/backend/sync/test_cog.py4
-rw-r--r--tests/bot/exts/backend/sync/test_users.py120
-rw-r--r--tests/bot/exts/filters/test_token_remover.py150
-rw-r--r--tests/bot/exts/info/test_information.py175
-rw-r--r--tests/bot/exts/moderation/infraction/test_infractions.py148
-rw-r--r--tests/bot/exts/moderation/test_silence.py587
-rw-r--r--tests/bot/exts/utils/test_snekbox.py21
-rw-r--r--tests/bot/patches/__init__.py0
-rw-r--r--tests/helpers.py21
11 files changed, 919 insertions, 730 deletions
diff --git a/tests/_autospec.py b/tests/_autospec.py
new file mode 100644
index 000000000..ee2fc1973
--- /dev/null
+++ b/tests/_autospec.py
@@ -0,0 +1,64 @@
+import contextlib
+import functools
+import unittest.mock
+from typing import Callable
+
+
[email protected](unittest.mock._patch.decoration_helper)
+def _decoration_helper(self, patched, args, keywargs):
+ """Skips adding patchings as args if their `dont_pass` attribute is True."""
+ # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added.
+ extra_args = []
+ with contextlib.ExitStack() as exit_stack:
+ for patching in patched.patchings:
+ arg = exit_stack.enter_context(patching)
+ if not getattr(patching, "dont_pass", False):
+ # Only add the patching as an arg if dont_pass is False.
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is unittest.mock.DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield args, keywargs
+
+
[email protected](unittest.mock._patch.copy)
+def _copy(self):
+ """Copy the `dont_pass` attribute along with the standard copy operation."""
+ patcher_copy = _copy.original(self)
+ patcher_copy.dont_pass = getattr(self, "dont_pass", False)
+ return patcher_copy
+
+
+# Monkey-patch the patcher class :)
+_copy.original = unittest.mock._patch.copy
+unittest.mock._patch.copy = _copy
+unittest.mock._patch.decoration_helper = _decoration_helper
+
+
+def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable:
+ """
+ Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.
+
+ If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object.
+ """
+ # Caller's kwargs should take priority and overwrite the defaults.
+ kwargs = dict(spec_set=True, autospec=True)
+ kwargs.update(patch_kwargs)
+
+ # Import the target if it's a string.
+ # This is to support both object and string targets like patch.multiple.
+ if type(target) is str:
+ target = unittest.mock._importer(target)
+
+ def decorator(func):
+ for attribute in attributes:
+ patcher = unittest.mock.patch.object(target, attribute, **kwargs)
+ if not pass_mocks:
+ # A custom attribute to keep track of which patchings should be skipped.
+ patcher.dont_pass = True
+ func = patcher(func)
+ return func
+ return decorator
diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py
index 886c243cf..4953550f9 100644
--- a/tests/bot/exts/backend/sync/test_base.py
+++ b/tests/bot/exts/backend/sync/test_base.py
@@ -1,12 +1,9 @@
-import asyncio
import unittest
from unittest import mock
-import discord
-from bot import constants
from bot.api import ResponseCodeError
-from bot.exts.backend.sync._syncers import Syncer, _Diff
+from bot.exts.backend.sync._syncers import Syncer
from tests import helpers
@@ -30,280 +27,16 @@ class SyncerBaseTests(unittest.TestCase):
Syncer(self.bot)
-class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):
- """Tests for sending the sync confirmation prompt."""
-
- def setUp(self):
- self.bot = helpers.MockBot()
- self.syncer = TestSyncer(self.bot)
-
- def mock_get_channel(self):
- """Fixture to return a mock channel and message for when `get_channel` is used."""
- self.bot.reset_mock()
-
- mock_channel = helpers.MockTextChannel()
- mock_message = helpers.MockMessage()
-
- mock_channel.send.return_value = mock_message
- self.bot.get_channel.return_value = mock_channel
-
- return mock_channel, mock_message
-
- def mock_fetch_channel(self):
- """Fixture to return a mock channel and message for when `fetch_channel` is used."""
- self.bot.reset_mock()
-
- mock_channel = helpers.MockTextChannel()
- mock_message = helpers.MockMessage()
-
- self.bot.get_channel.return_value = None
- mock_channel.send.return_value = mock_message
- self.bot.fetch_channel.return_value = mock_channel
-
- return mock_channel, mock_message
-
- async def test_send_prompt_edits_and_returns_message(self):
- """The given message should be edited to display the prompt and then should be returned."""
- msg = helpers.MockMessage()
- ret_val = await self.syncer._send_prompt(msg)
-
- msg.edit.assert_called_once()
- self.assertIn("content", msg.edit.call_args[1])
- self.assertEqual(ret_val, msg)
-
- async def test_send_prompt_gets_dev_core_channel(self):
- """The dev-core channel should be retrieved if an extant message isn't given."""
- subtests = (
- (self.bot.get_channel, self.mock_get_channel),
- (self.bot.fetch_channel, self.mock_fetch_channel),
- )
-
- for method, mock_ in subtests:
- with self.subTest(method=method, msg=mock_.__name__):
- mock_()
- await self.syncer._send_prompt()
-
- method.assert_called_once_with(constants.Channels.dev_core)
-
- async def test_send_prompt_returns_none_if_channel_fetch_fails(self):
- """None should be returned if there's an HTTPException when fetching the channel."""
- self.bot.get_channel.return_value = None
- self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
-
- ret_val = await self.syncer._send_prompt()
-
- self.assertIsNone(ret_val)
-
- async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
- """A new message mentioning core devs should be sent and returned if message isn't given."""
- for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
- with self.subTest(msg=mock_.__name__):
- mock_channel, mock_message = mock_()
- ret_val = await self.syncer._send_prompt()
-
- mock_channel.send.assert_called_once()
- self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
- self.assertEqual(ret_val, mock_message)
-
- async def test_send_prompt_adds_reactions(self):
- """The message should have reactions for confirmation added."""
- extant_message = helpers.MockMessage()
- subtests = (
- (extant_message, lambda: (None, extant_message)),
- (None, self.mock_get_channel),
- (None, self.mock_fetch_channel),
- )
-
- for message_arg, mock_ in subtests:
- subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__
-
- with self.subTest(msg=subtest_msg):
- _, mock_message = mock_()
- await self.syncer._send_prompt(message_arg)
-
- calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS]
- mock_message.add_reaction.assert_has_calls(calls)
-
-
-class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):
- """Tests for waiting for a sync confirmation reaction on the prompt."""
-
- def setUp(self):
- self.bot = helpers.MockBot()
- self.syncer = TestSyncer(self.bot)
- self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers)
-
- @staticmethod
- def get_message_reaction(emoji):
- """Fixture to return a mock message an reaction from the given `emoji`."""
- message = helpers.MockMessage()
- reaction = helpers.MockReaction(emoji=emoji, message=message)
-
- return message, reaction
-
- def test_reaction_check_for_valid_emoji_and_authors(self):
- """Should return True if authors are identical or are a bot and a core dev, respectively."""
- user_subtests = (
- (
- helpers.MockMember(id=77),
- helpers.MockMember(id=77),
- "identical users",
- ),
- (
- helpers.MockMember(id=77, bot=True),
- helpers.MockMember(id=43, roles=[self.core_dev_role]),
- "bot author and core-dev reactor",
- ),
- )
-
- for emoji in self.syncer._REACTION_EMOJIS:
- for author, user, msg in user_subtests:
- with self.subTest(author=author, user=user, emoji=emoji, msg=msg):
- message, reaction = self.get_message_reaction(emoji)
- ret_val = self.syncer._reaction_check(author, message, reaction, user)
-
- self.assertTrue(ret_val)
-
- def test_reaction_check_for_invalid_reactions(self):
- """Should return False for invalid reaction events."""
- valid_emoji = self.syncer._REACTION_EMOJIS[0]
- subtests = (
- (
- helpers.MockMember(id=77),
- *self.get_message_reaction(valid_emoji),
- helpers.MockMember(id=43, roles=[self.core_dev_role]),
- "users are not identical",
- ),
- (
- helpers.MockMember(id=77, bot=True),
- *self.get_message_reaction(valid_emoji),
- helpers.MockMember(id=43),
- "reactor lacks the core-dev role",
- ),
- (
- helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
- *self.get_message_reaction(valid_emoji),
- helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
- "reactor is a bot",
- ),
- (
- helpers.MockMember(id=77),
- helpers.MockMessage(id=95),
- helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)),
- helpers.MockMember(id=77),
- "messages are not identical",
- ),
- (
- helpers.MockMember(id=77),
- *self.get_message_reaction("InVaLiD"),
- helpers.MockMember(id=77),
- "emoji is invalid",
- ),
- )
-
- for *args, msg in subtests:
- kwargs = dict(zip(("author", "message", "reaction", "user"), args))
- with self.subTest(**kwargs, msg=msg):
- ret_val = self.syncer._reaction_check(*args)
- self.assertFalse(ret_val)
-
- async def test_wait_for_confirmation(self):
- """The message should always be edited and only return True if the emoji is a check mark."""
- subtests = (
- (constants.Emojis.check_mark, True, None),
- ("InVaLiD", False, None),
- (None, False, asyncio.TimeoutError),
- )
-
- for emoji, ret_val, side_effect in subtests:
- for bot in (True, False):
- with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot):
- # Set up mocks
- message = helpers.MockMessage()
- member = helpers.MockMember(bot=bot)
-
- self.bot.wait_for.reset_mock()
- self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None)
- self.bot.wait_for.side_effect = side_effect
-
- # Call the function
- actual_return = await self.syncer._wait_for_confirmation(member, message)
-
- # Perform assertions
- self.bot.wait_for.assert_called_once()
- self.assertIn("reaction_add", self.bot.wait_for.call_args[0])
-
- message.edit.assert_called_once()
- kwargs = message.edit.call_args[1]
- self.assertIn("content", kwargs)
-
- # Core devs should only be mentioned if the author is a bot.
- if bot:
- self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
- else:
- self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
-
- self.assertIs(actual_return, ret_val)
-
-
class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for main function orchestrating the sync."""
def setUp(self):
self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
self.syncer = TestSyncer(self.bot)
+ self.guild = helpers.MockGuild()
- async def test_sync_respects_confirmation_result(self):
- """The sync should abort if confirmation fails and continue if confirmed."""
- mock_message = helpers.MockMessage()
- subtests = (
- (True, mock_message),
- (False, None),
- )
-
- for confirmed, message in subtests:
- with self.subTest(confirmed=confirmed):
- self.syncer._sync.reset_mock()
- self.syncer._get_diff.reset_mock()
-
- diff = _Diff({1, 2, 3}, {4, 5}, None)
- self.syncer._get_diff.return_value = diff
- self.syncer._get_confirmation_result = mock.AsyncMock(
- return_value=(confirmed, message)
- )
-
- guild = helpers.MockGuild()
- await self.syncer.sync(guild)
-
- self.syncer._get_diff.assert_called_once_with(guild)
- self.syncer._get_confirmation_result.assert_called_once()
-
- if confirmed:
- self.syncer._sync.assert_called_once_with(diff)
- else:
- self.syncer._sync.assert_not_called()
-
- async def test_sync_diff_size(self):
- """The diff size should be correctly calculated."""
- subtests = (
- (6, _Diff({1, 2}, {3, 4}, {5, 6})),
- (5, _Diff({1, 2, 3}, None, {4, 5})),
- (0, _Diff(None, None, None)),
- (0, _Diff(set(), set(), set())),
- )
-
- for size, diff in subtests:
- with self.subTest(size=size, diff=diff):
- self.syncer._get_diff.reset_mock()
- self.syncer._get_diff.return_value = diff
- self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
-
- guild = helpers.MockGuild()
- await self.syncer.sync(guild)
-
- self.syncer._get_diff.assert_called_once_with(guild)
- self.syncer._get_confirmation_result.assert_called_once()
- self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
+ # Make sure `_get_diff` returns a MagicMock, not an AsyncMock
+ self.syncer._get_diff.return_value = mock.MagicMock()
async def test_sync_message_edited(self):
"""The message should be edited if one was sent, even if the sync has an API error."""
@@ -316,89 +49,25 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
for message, side_effect, should_edit in subtests:
with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
self.syncer._sync.side_effect = side_effect
- self.syncer._get_confirmation_result = mock.AsyncMock(
- return_value=(True, message)
- )
+ ctx = helpers.MockContext()
+ ctx.send.return_value = message
- guild = helpers.MockGuild()
- await self.syncer.sync(guild)
+ await self.syncer.sync(self.guild, ctx)
if should_edit:
message.edit.assert_called_once()
self.assertIn("content", message.edit.call_args[1])
- async def test_sync_confirmation_context_redirect(self):
- """If ctx is given, a new message should be sent and author should be ctx's author."""
- mock_member = helpers.MockMember()
+ async def test_sync_message_sent(self):
+ """If ctx is given, a new message should be sent."""
subtests = (
- (None, self.bot.user, None),
- (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()),
+ (None, None),
+ (helpers.MockContext(), helpers.MockMessage()),
)
- for ctx, author, message in subtests:
- with self.subTest(ctx=ctx, author=author, message=message):
- if ctx is not None:
- ctx.send.return_value = message
-
- # Make sure `_get_diff` returns a MagicMock, not an AsyncMock
- self.syncer._get_diff.return_value = mock.MagicMock()
-
- self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
-
- guild = helpers.MockGuild()
- await self.syncer.sync(guild, ctx)
+ for ctx, message in subtests:
+ with self.subTest(ctx=ctx, message=message):
+ await self.syncer.sync(self.guild, ctx)
if ctx is not None:
ctx.send.assert_called_once()
-
- self.syncer._get_confirmation_result.assert_called_once()
- self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author)
- self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
-
- @mock.patch.object(constants.Sync, "max_diff", new=3)
- async def test_confirmation_result_small_diff(self):
- """Should always return True and the given message if the diff size is too small."""
- author = helpers.MockMember()
- expected_message = helpers.MockMessage()
-
- for size in (3, 2): # pragma: no cover
- with self.subTest(size=size):
- self.syncer._send_prompt = mock.AsyncMock()
- self.syncer._wait_for_confirmation = mock.AsyncMock()
-
- coro = self.syncer._get_confirmation_result(size, author, expected_message)
- result, actual_message = await coro
-
- self.assertTrue(result)
- self.assertEqual(actual_message, expected_message)
- self.syncer._send_prompt.assert_not_called()
- self.syncer._wait_for_confirmation.assert_not_called()
-
- @mock.patch.object(constants.Sync, "max_diff", new=3)
- async def test_confirmation_result_large_diff(self):
- """Should return True if confirmed and False if _send_prompt fails or aborted."""
- author = helpers.MockMember()
- mock_message = helpers.MockMessage()
-
- subtests = (
- (True, mock_message, True, "confirmed"),
- (False, None, False, "_send_prompt failed"),
- (False, mock_message, False, "aborted"),
- )
-
- for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover
- with self.subTest(msg=msg):
- self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message)
- self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed)
-
- coro = self.syncer._get_confirmation_result(4, author)
- actual_result, actual_message = await coro
-
- self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None
- self.assertIs(actual_result, expected_result)
- self.assertEqual(actual_message, expected_message)
-
- if expected_message:
- self.syncer._wait_for_confirmation.assert_called_once_with(
- author, expected_message
- )
diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py
index 1b89564f2..063a82754 100644
--- a/tests/bot/exts/backend/sync/test_cog.py
+++ b/tests/bot/exts/backend/sync/test_cog.py
@@ -392,14 +392,14 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
async def test_sync_roles_command(self):
"""sync() should be called on the RoleSyncer."""
ctx = helpers.MockContext()
- await self.cog.sync_roles_command.callback(self.cog, ctx)
+ await self.cog.sync_roles_command(self.cog, ctx)
self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
async def test_sync_users_command(self):
"""sync() should be called on the UserSyncer."""
ctx = helpers.MockContext()
- await self.cog.sync_users_command.callback(self.cog, ctx)
+ await self.cog.sync_users_command(self.cog, ctx)
self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index c0a1da35c..9f380a15d 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,7 +1,6 @@
import unittest
-from unittest import mock
-from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User
+from bot.exts.backend.sync._syncers import UserSyncer, _Diff
from tests import helpers
@@ -10,7 +9,7 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("roles", [666])
kwargs.setdefault("in_guild", True)
return kwargs
@@ -40,22 +39,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
return guild
+ @staticmethod
+ def get_mock_member(member: dict):
+ member = member.copy()
+ del member["in_guild"]
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+ return mock_member
+
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": []
+ }
guild = self.get_guild()
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_identical_users(self):
"""No differences should be found if the users in the guild and DB are identical."""
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.return_value = self.get_mock_member(fake_user())
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -63,59 +82,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(id=99, name="old"), fake_user()]
+ }
guild = self.get_guild(updated_user, fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(updated_user),
+ self.get_mock_member(fake_user())
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**updated_user)}, None)
+ expected_diff = ([], [{"id": 99, "name": "new"}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_users(self):
- """Only new users should be added to the 'created' set of the diff."""
+ """Only new users should be added to the 'created' list of the diff."""
new_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user(), new_user)
-
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(new_user)
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, set(), None)
+ expected_diff = ([new_user], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- leaving_user = fake_user(id=63, in_guild=False)
-
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**leaving_user)}, None)
+ expected_diff = ([], [{"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_updated_and_leaving_users(self):
"""When users are added, updated, and removed, all of them are returned properly."""
new_user = fake_user(id=99, name="new")
+
updated_user = fake_user(id=55, name="updated")
- leaving_user = fake_user(id=63, in_guild=False)
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=55), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user(), new_user, updated_user)
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(updated_user),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+ expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_db_users_not_in_guild(self):
- """When the DB knows a user the guild doesn't, no difference is found."""
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ """When the DB knows a user, but the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=63, in_guild=False)]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -131,13 +193,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(user_tuples, set(), None)
+ diff = _Diff(users, [], None)
await self.syncer._sync(diff)
- calls = [mock.call("bot/users", json=user) for user in users]
- self.bot.api_client.post.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.post.call_count, len(users))
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
@@ -146,13 +205,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only PUT requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(set(), user_tuples, None)
+ diff = _Diff([], users, None)
await self.syncer._sync(diff)
- calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
- self.bot.api_client.put.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.put.call_count, len(users))
+ self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py
index ea822053b..f99cc3370 100644
--- a/tests/bot/exts/filters/test_token_remover.py
+++ b/tests/bot/exts/filters/test_token_remover.py
@@ -23,23 +23,25 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
self.msg = MockMessage(id=555, content="hello world")
self.msg.channel.mention = "#lemonade-stand"
+ self.msg.guild.get_member.return_value.bot = False
+ self.msg.guild.get_member.return_value.__str__.return_value = "Woody"
self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
self.msg.author.avatar_url_as.return_value = "picture-lemon.png"
- def test_is_valid_user_id_valid(self):
- """Should consider user IDs valid if they decode entirely to ASCII digits."""
- ids = (
- "NDcyMjY1OTQzMDYyNDEzMzMy",
- "NDc1MDczNjI5Mzk5NTQ3OTA0",
- "NDY3MjIzMjMwNjUwNzc3NjQx",
+ def test_extract_user_id_valid(self):
+ """Should consider user IDs valid if they decode into an integer ID."""
+ id_pairs = (
+ ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332),
+ ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904),
+ ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641),
)
- for user_id in ids:
- with self.subTest(user_id=user_id):
- result = TokenRemover.is_valid_user_id(user_id)
- self.assertTrue(result)
+ for token_id, user_id in id_pairs:
+ with self.subTest(token_id=token_id):
+ result = TokenRemover.extract_user_id(token_id)
+ self.assertEqual(result, user_id)
- def test_is_valid_user_id_invalid(self):
+ def test_extract_user_id_invalid(self):
"""Should consider non-digit and non-ASCII IDs invalid."""
ids = (
("SGVsbG8gd29ybGQ", "non-digit ASCII"),
@@ -53,8 +55,8 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
for user_id, msg in ids:
with self.subTest(msg=msg):
- result = TokenRemover.is_valid_user_id(user_id)
- self.assertFalse(result)
+ result = TokenRemover.extract_user_id(user_id)
+ self.assertIsNone(result)
def test_is_valid_timestamp_valid(self):
"""Should consider timestamps valid if they're greater than the Discord epoch."""
@@ -86,6 +88,34 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
result = TokenRemover.is_valid_timestamp(timestamp)
self.assertFalse(result)
+ def test_is_valid_hmac_valid(self):
+ """Should consider an HMAC valid if it has at least 3 unique characters."""
+ valid_hmacs = (
+ "VXmErH7j511turNpfURmb0rVNm8",
+ "Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "sJf6omBPORBPju3WJEIAcwW9Zds",
+ "s45jqDV_Iisn-symw0yDRrk_jf4",
+ )
+
+ for hmac in valid_hmacs:
+ with self.subTest(msg=hmac):
+ result = TokenRemover.is_maybe_valid_hmac(hmac)
+ self.assertTrue(result)
+
+ def test_is_invalid_hmac_invalid(self):
+ """Should consider an HMAC invalid if has fewer than 3 unique characters."""
+ invalid_hmacs = (
+ ("xxxxxxxxxxxxxxxxxx", "Single character"),
+ ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"),
+ ("ASFasfASFasfASFASsf", "Three characters alternating-case"),
+ ("asdasdasdasdasdasdasd", "Three characters one case"),
+ )
+
+ for hmac, msg in invalid_hmacs:
+ with self.subTest(msg=msg):
+ result = TokenRemover.is_maybe_valid_hmac(hmac)
+ self.assertFalse(result)
+
def test_mod_log_property(self):
"""The `mod_log` property should ask the bot to return the `ModLog` cog."""
self.bot.get_cog.return_value = 'lemon'
@@ -143,11 +173,18 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
self.assertIsNone(return_value)
token_re.finditer.assert_called_once_with(self.msg.content)
- @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp")
+ @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
@autospec("bot.exts.filters.token_remover", "Token")
@autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_valid_match(self, token_re, token_cls, is_valid_id, is_valid_timestamp):
- """The first match with a valid user ID and timestamp should be returned as a `Token`."""
+ def test_find_token_valid_match(
+ self,
+ token_re,
+ token_cls,
+ extract_user_id,
+ is_valid_timestamp,
+ is_maybe_valid_hmac,
+ ):
+ """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`."""
matches = [
mock.create_autospec(Match, spec_set=True, instance=True),
mock.create_autospec(Match, spec_set=True, instance=True),
@@ -159,23 +196,32 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
token_re.finditer.return_value = matches
token_cls.side_effect = tokens
- is_valid_id.side_effect = (False, True) # The 1st match will be invalid, 2nd one valid.
+ extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid.
is_valid_timestamp.return_value = True
+ is_maybe_valid_hmac.return_value = True
return_value = TokenRemover.find_token_in_message(self.msg)
self.assertEqual(tokens[1], return_value)
token_re.finditer.assert_called_once_with(self.msg.content)
- @autospec(TokenRemover, "is_valid_user_id", "is_valid_timestamp")
+ @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
@autospec("bot.exts.filters.token_remover", "Token")
@autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_invalid_matches(self, token_re, token_cls, is_valid_id, is_valid_timestamp):
- """None should be returned if no matches have valid user IDs or timestamps."""
+ def test_find_token_invalid_matches(
+ self,
+ token_re,
+ token_cls,
+ extract_user_id,
+ is_valid_timestamp,
+ is_maybe_valid_hmac,
+ ):
+ """None should be returned if no matches have valid user IDs, HMACs, and timestamps."""
token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
- is_valid_id.return_value = False
+ extract_user_id.return_value = None
is_valid_timestamp.return_value = False
+ is_maybe_valid_hmac.return_value = False
return_value = TokenRemover.find_token_in_message(self.msg)
@@ -234,7 +280,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
@autospec("bot.exts.filters.token_remover", "LOG_MESSAGE")
def test_format_log_message(self, log_message):
"""Should correctly format the log message with info from the message and token."""
- token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
log_message.format.return_value = "Howdy"
return_value = TokenRemover.format_log_message(self.msg, token)
@@ -248,18 +294,68 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
hmac="x" * len(token.hmac),
)
+ @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE")
+ def test_format_userid_log_message_unknown(self, unknown_user_log_message):
+ """Should correctly format the user ID portion when the actual user it belongs to is unknown."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ unknown_user_log_message.format.return_value = " Partner"
+ msg = MockMessage(id=555, content="hello world")
+ msg.guild.get_member.return_value = None
+
+ return_value = TokenRemover.format_userid_log_message(msg, token)
+
+ self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False))
+ unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332)
+
+ @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
+ def test_format_userid_log_message_bot(self, known_user_log_message):
+ """Should correctly format the user ID portion when the ID belongs to a known bot."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ known_user_log_message.format.return_value = " Partner"
+ msg = MockMessage(id=555, content="hello world")
+ msg.guild.get_member.return_value.__str__.return_value = "Sam"
+ msg.guild.get_member.return_value.bot = True
+
+ return_value = TokenRemover.format_userid_log_message(msg, token)
+
+ self.assertEqual(return_value, (known_user_log_message.format.return_value, False))
+
+ known_user_log_message.format.assert_called_once_with(
+ user_id=472265943062413332,
+ user_name="Sam",
+ kind="BOT",
+ )
+
+ @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
+ def test_format_log_message_user_token_user(self, user_token_message):
+ """Should correctly format the user ID portion when the ID belongs to a known user."""
+ token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ user_token_message.format.return_value = "Partner"
+
+ return_value = TokenRemover.format_userid_log_message(self.msg, token)
+
+ self.assertEqual(return_value, (user_token_message.format.return_value, True))
+ user_token_message.format.assert_called_once_with(
+ user_id=467223230650777641,
+ user_name="Woody",
+ kind="USER",
+ )
+
@mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
@autospec("bot.exts.filters.token_remover", "log")
- @autospec(TokenRemover, "format_log_message")
- async def test_take_action(self, format_log_message, logger, mod_log_property):
+ @autospec(TokenRemover, "format_log_message", "format_userid_log_message")
+ async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property):
"""Should delete the message and send a mod log."""
cog = TokenRemover(self.bot)
mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True)
token = mock.create_autospec(Token, spec_set=True, instance=True)
+ token.user_id = "no-id"
log_msg = "testing123"
+ userid_log_message = "userid-log-message"
mod_log_property.return_value = mod_log
format_log_message.return_value = log_msg
+ format_userid_log_message.return_value = (userid_log_message, True)
await cog.take_action(self.msg, token)
@@ -269,6 +365,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
)
format_log_message.assert_called_once_with(self.msg, token)
+ format_userid_log_message.assert_called_once_with(self.msg, token)
logger.debug.assert_called_with(log_msg)
self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens")
@@ -277,9 +374,10 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
icon_url=constants.Icons.token_removed,
colour=Colour(constants.Colours.soft_red),
title="Token removed!",
- text=log_msg,
+ text=log_msg + "\n" + userid_log_message,
thumbnail=self.msg.author.avatar_url_as.return_value,
- channel_id=constants.Channels.mod_alerts
+ channel_id=constants.Channels.mod_alerts,
+ ping_everyone=True,
)
@mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index d3f2995fb..daede54c5 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -1,4 +1,3 @@
-import asyncio
import textwrap
import unittest
import unittest.mock
@@ -13,7 +12,7 @@ from tests import helpers
COG_PATH = "bot.exts.info.information.Information"
-class InformationCogTests(unittest.TestCase):
+class InformationCogTests(unittest.IsolatedAsyncioTestCase):
"""Tests the Information cog."""
@classmethod
@@ -29,16 +28,14 @@ class InformationCogTests(unittest.TestCase):
self.ctx = helpers.MockContext()
self.ctx.author.roles.append(self.moderator_role)
- def test_roles_command_command(self):
+ async def test_roles_command_command(self):
"""Test if the `role_info` command correctly returns the `moderator_role`."""
self.ctx.guild.roles.append(self.moderator_role)
self.cog.roles_info.can_run = unittest.mock.AsyncMock()
self.cog.roles_info.can_run.return_value = True
- coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
-
- self.assertIsNone(asyncio.run(coroutine))
+ self.assertIsNone(await self.cog.roles_info(self.cog, self.ctx))
self.ctx.send.assert_called_once()
_, kwargs = self.ctx.send.call_args
@@ -48,7 +45,7 @@ class InformationCogTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour.blurple())
self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
- def test_role_info_command(self):
+ async def test_role_info_command(self):
"""Tests the `role info` command."""
dummy_role = helpers.MockRole(
name="Dummy",
@@ -73,9 +70,7 @@ class InformationCogTests(unittest.TestCase):
self.cog.role_info.can_run = unittest.mock.AsyncMock()
self.cog.role_info.can_run.return_value = True
- coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
-
- self.assertIsNone(asyncio.run(coroutine))
+ self.assertIsNone(await self.cog.role_info(self.cog, self.ctx, dummy_role, admin_role))
self.assertEqual(self.ctx.send.call_count, 2)
@@ -97,80 +92,8 @@ class InformationCogTests(unittest.TestCase):
self.assertEqual(admin_embed.title, "Admins info")
self.assertEqual(admin_embed.colour, discord.Colour.red())
- @unittest.mock.patch('bot.exts.info.information.time_since')
- def test_server_info_command(self, time_since_patch):
- time_since_patch.return_value = '2 days ago'
-
- self.ctx.guild = helpers.MockGuild(
- features=('lemons', 'apples'),
- region="The Moon",
- roles=[self.moderator_role],
- channels=[
- discord.TextChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'}
- ),
- discord.CategoryChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'}
- ),
- discord.VoiceChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'}
- )
- ],
- members=[
- *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
- *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
- *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
- *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
- ],
- member_count=1_234,
- icon_url='a-lemon.jpg',
- )
-
- coroutine = self.cog.server_info.callback(self.cog, self.ctx)
- self.assertIsNone(asyncio.run(coroutine))
-
- time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days')
- _, kwargs = self.ctx.send.call_args
- embed = kwargs.pop('embed')
- self.assertEqual(embed.colour, discord.Colour.blurple())
- self.assertEqual(
- embed.description,
- textwrap.dedent(
- f"""
- **Server information**
- Created: {time_since_patch.return_value}
- Voice region: {self.ctx.guild.region}
- Features: {', '.join(self.ctx.guild.features)}
-
- **Channel counts**
- Category channels: 1
- Text channels: 1
- Voice channels: 1
- Staff channels: 0
-
- **Member counts**
- Members: {self.ctx.guild.member_count:,}
- Staff members: 0
- Roles: {len(self.ctx.guild.roles)}
-
- **Member statuses**
- {constants.Emojis.status_online} 2
- {constants.Emojis.status_idle} 1
- {constants.Emojis.status_dnd} 4
- {constants.Emojis.status_offline} 3
- """
- )
- )
- self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg')
-
-class UserInfractionHelperMethodTests(unittest.TestCase):
+class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the helper methods of the `!user` command."""
def setUp(self):
@@ -180,7 +103,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
self.cog = information.Information(self.bot)
self.member = helpers.MockMember(id=1234)
- def test_user_command_helper_method_get_requests(self):
+ async def test_user_command_helper_method_get_requests(self):
"""The helper methods should form the correct get requests."""
test_values = (
{
@@ -202,11 +125,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
endpoint, params = test_value["expected_args"]
with self.subTest(method=helper_method, endpoint=endpoint, params=params):
- asyncio.run(helper_method(self.member))
+ await helper_method(self.member)
self.bot.api_client.get.assert_called_once_with(endpoint, params=params)
self.bot.api_client.get.reset_mock()
- def _method_subtests(self, method, test_values, default_header):
+ async def _method_subtests(self, method, test_values, default_header):
"""Helper method that runs the subtests for the different helper methods."""
for test_value in test_values:
api_response = test_value["api response"]
@@ -216,11 +139,11 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
self.bot.api_client.get.return_value = api_response
expected_output = "\n".join(expected_lines)
- actual_output = asyncio.run(method(self.member))
+ actual_output = await method(self.member)
self.assertEqual((default_header, expected_output), actual_output)
- def test_basic_user_infraction_counts_returns_correct_strings(self):
+ async def test_basic_user_infraction_counts_returns_correct_strings(self):
"""The method should correctly list both the total and active number of non-hidden infractions."""
test_values = (
# No infractions means zero counts
@@ -251,9 +174,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
header = "Infractions"
- self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header)
+ await self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header)
- def test_expanded_user_infraction_counts_returns_correct_strings(self):
+ async def test_expanded_user_infraction_counts_returns_correct_strings(self):
"""The method should correctly list the total and active number of all infractions split by infraction type."""
test_values = (
{
@@ -306,9 +229,9 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
header = "Infractions"
- self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header)
+ await self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header)
- def test_user_nomination_counts_returns_correct_strings(self):
+ async def test_user_nomination_counts_returns_correct_strings(self):
"""The method should list the number of active and historical nominations for the user."""
test_values = (
{
@@ -336,12 +259,12 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
header = "Nominations"
- self._method_subtests(self.cog.user_nomination_counts, test_values, header)
+ await self._method_subtests(self.cog.user_nomination_counts, test_values, header)
@unittest.mock.patch("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago"))
@unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50])
-class UserEmbedTests(unittest.TestCase):
+class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the creation of the `!user` embed."""
def setUp(self):
@@ -354,14 +277,14 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
+ async def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
user.nick = None
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.title, "Mr. Hemlock")
@@ -369,14 +292,14 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_nick_in_title_if_available(self):
+ async def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
user.nick = "Cat lover"
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
@@ -384,7 +307,7 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_ignores_everyone_role(self):
+ async def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
admins_role = helpers.MockRole(name='Admins')
@@ -393,14 +316,18 @@ class UserEmbedTests(unittest.TestCase):
# A `MockMember` has the @Everyone role by default; we add the Admins to that.
user = helpers.MockMember(roles=[admins_role], top_role=admins_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertIn("&Admins", embed.fields[1].value)
self.assertNotIn("&Everyone", embed.fields[1].value)
@unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
@unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
- def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
+ async def test_create_user_embed_expanded_information_in_moderation_channels(
+ self,
+ nomination_counts,
+ infraction_counts
+ ):
"""The embed should contain expanded infractions and nomination info in mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50))
@@ -411,7 +338,7 @@ class UserEmbedTests(unittest.TestCase):
nomination_counts.return_value = ("Nominations", "nomination info")
user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
infraction_counts.assert_called_once_with(user)
nomination_counts.assert_called_once_with(user)
@@ -434,7 +361,7 @@ class UserEmbedTests(unittest.TestCase):
)
@unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
- def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
+ async def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
@@ -444,7 +371,7 @@ class UserEmbedTests(unittest.TestCase):
infraction_counts.return_value = ("Infractions", "basic infractions info")
user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
infraction_counts.assert_called_once_with(user)
@@ -467,14 +394,14 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(
"basic infractions info",
- embed.fields[3].value
+ embed.fields[2].value
)
@unittest.mock.patch(
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
+ async def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -482,7 +409,7 @@ class UserEmbedTests(unittest.TestCase):
moderators_role.colour = 100
user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
@@ -490,12 +417,12 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
+ async def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
"""The embed should be created with a blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
user = helpers.MockMember(id=217)
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
self.assertEqual(embed.colour, discord.Colour.blurple())
@@ -503,20 +430,20 @@ class UserEmbedTests(unittest.TestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
+ async def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
user = helpers.MockMember(id=217)
user.avatar_url_as.return_value = "avatar url"
- embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+ embed = await self.cog.create_user_embed(ctx, user)
user.avatar_url_as.assert_called_once_with(static_format="png")
self.assertEqual(embed.thumbnail.url, "avatar url")
@unittest.mock.patch("bot.exts.info.information.constants")
-class UserCommandTests(unittest.TestCase):
+class UserCommandTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the `!user` command."""
def setUp(self):
@@ -536,16 +463,16 @@ class UserCommandTests(unittest.TestCase):
# used as a default value for a parameter, which gets defined upon import.
self.bot_command_channel = helpers.MockTextChannel(id=constants.Channels.bot_commands)
- def test_regular_member_cannot_target_another_member(self, constants):
+ async def test_regular_member_cannot_target_another_member(self, constants):
"""A regular user should not be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.author)
- asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target))
+ await self.cog.user_info(self.cog, ctx, self.target)
ctx.send.assert_called_once_with("You may not use this command on users other than yourself.")
- def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants):
+ async def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants):
"""A regular user should not be able to use this command outside of bot-commands."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
@@ -553,49 +480,49 @@ class UserCommandTests(unittest.TestCase):
msg = "Sorry, but you may only use this command within <#50>."
with self.assertRaises(InWhitelistCheckFailure, msg=msg):
- asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+ await self.cog.user_info(self.cog, ctx)
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
+ async def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel)
- asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+ await self.cog.user_info(self.cog, ctx)
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_regular_user_can_explicitly_target_themselves(self, create_embed, _):
+ async def test_regular_user_can_explicitly_target_themselves(self, create_embed, _):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.author, channel=self.bot_command_channel)
- asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author))
+ await self.cog.user_info(self.cog, ctx, self.author)
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
+ async def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))
- asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+ await self.cog.user_info(self.cog, ctx)
create_embed.assert_called_once_with(ctx, self.moderator)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
- def test_moderators_can_target_another_member(self, create_embed, constants):
+ async def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50))
- asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target))
+ await self.cog.user_info(self.cog, ctx, self.target)
create_embed.assert_called_once_with(ctx, self.target)
ctx.send.assert_called_once()
diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py
index be1b649e1..bf557a484 100644
--- a/tests/bot/exts/moderation/infraction/test_infractions.py
+++ b/tests/bot/exts/moderation/infraction/test_infractions.py
@@ -1,7 +1,8 @@
import textwrap
import unittest
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from bot.constants import Event
from bot.exts.moderation.infraction.infractions import Infractions
from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole
@@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
self.cog.apply_infraction.assert_awaited_once_with(
self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value
)
+
+
+@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456)
+class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for voice ban related functions and commands."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.mod = MockMember(top_role=10)
+ self.user = MockMember(top_role=1, roles=[MockRole(id=123456)])
+ self.guild = MockGuild()
+ self.ctx = MockContext(bot=self.bot, author=self.mod)
+ self.cog = Infractions(self.bot)
+
+ async def test_permanent_voice_ban(self):
+ """Should call voice ban applying function without expiry."""
+ self.cog.apply_voice_ban = AsyncMock()
+ self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar"))
+ self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar")
+
+ async def test_temporary_voice_ban(self):
+ """Should call voice ban applying function with expiry."""
+ self.cog.apply_voice_ban = AsyncMock()
+ self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar"))
+ self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz")
+
+ async def test_voice_unban(self):
+ """Should call infraction pardoning function."""
+ self.cog.pardon_infraction = AsyncMock()
+ self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user))
+ self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user)
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock):
+ """Should return early when user already have Voice Ban infraction."""
+ get_active_infraction.return_value = {"foo": "bar"}
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban")
+ post_infraction_mock.assert_not_awaited()
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock):
+ """Should return early when posting infraction fails."""
+ self.cog.mod_log.ignore = MagicMock()
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = None
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ post_infraction_mock.assert_awaited_once()
+ self.cog.mod_log.ignore.assert_not_called()
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock):
+ """Should pass all kwargs passed to apply_voice_ban to post_infraction."""
+ get_active_infraction.return_value = None
+ # We don't want that this continue yet
+ post_infraction_mock.return_value = None
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23))
+ post_infraction_mock.assert_awaited_once_with(
+ self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23
+ )
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock):
+ """Should ignore Voice Verified role removing."""
+ self.cog.mod_log.ignore = MagicMock()
+ self.cog.apply_infraction = AsyncMock()
+ self.user.remove_roles = MagicMock(return_value="my_return_value")
+
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id)
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock):
+ """Should ignore Voice Verified role removing."""
+ self.cog.mod_log.ignore = MagicMock()
+ self.cog.apply_infraction = AsyncMock()
+ self.user.remove_roles = MagicMock(return_value="my_return_value")
+
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar")
+ self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value")
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock):
+ """Should truncate reason for voice ban."""
+ self.cog.mod_log.ignore = MagicMock()
+ self.cog.apply_infraction = AsyncMock()
+ self.user.remove_roles = MagicMock(return_value="my_return_value")
+
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000))
+ self.user.remove_roles.assert_called_once_with(
+ self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...")
+ )
+ self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value")
+
+ async def test_voice_unban_user_not_found(self):
+ """Should include info to return dict when user was not found from guild."""
+ self.guild.get_member.return_value = None
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.assertEqual(result, {"Info": "User was not found in the guild."})
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
+ @patch("bot.exts.moderation.infraction.infractions.format_user")
+ async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock):
+ """Should add role back with ignoring, notify user and return log dictionary.."""
+ self.guild.get_member.return_value = self.user
+ notify_pardon_mock.return_value = True
+ format_user_mock.return_value = "my-user"
+
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.assertEqual(result, {
+ "Member": "my-user",
+ "DM": "Sent"
+ })
+ notify_pardon_mock.assert_awaited_once()
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
+ @patch("bot.exts.moderation.infraction.infractions.format_user")
+ async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock):
+ """Should add role back with ignoring, notify user and return log dictionary.."""
+ self.guild.get_member.return_value = self.user
+ notify_pardon_mock.return_value = False
+ format_user_mock.return_value = "my-user"
+
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.assertEqual(result, {
+ "Member": "my-user",
+ "DM": "**Failed**"
+ })
+ notify_pardon_mock.assert_awaited_once()
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index e2d44c637..104293d8e 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -1,23 +1,49 @@
+import asyncio
import unittest
+from datetime import datetime, timezone
from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
+from async_rediscache import RedisSession
from discord import PermissionOverwrite
-from bot.constants import Channels, Emojis, Guild, Roles
-from bot.exts.moderation.silence import Silence, SilenceNotifier
-from tests.helpers import MockBot, MockContext, MockTextChannel
+from bot.constants import Channels, Guild, Roles
+from bot.exts.moderation import silence
+from tests.helpers import MockBot, MockContext, MockTextChannel, autospec
+
+redis_session = None
+redis_loop = asyncio.get_event_loop()
+
+
+def setUpModule(): # noqa: N802
+ """Create and connect to the fakeredis session."""
+ global redis_session
+ redis_session = RedisSession(use_fakeredis=True)
+ redis_loop.run_until_complete(redis_session.connect())
+
+
+def tearDownModule(): # noqa: N802
+ """Close the fakeredis session."""
+ if redis_session:
+ redis_loop.run_until_complete(redis_session.close())
+
+
+# Have to subclass it because builtins can't be patched.
+class PatchedDatetime(datetime):
+ """A datetime object with a mocked now() function."""
+
+ now = mock.create_autospec(datetime, "now")
class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.alert_channel = MockTextChannel()
- self.notifier = SilenceNotifier(self.alert_channel)
+ self.notifier = silence.SilenceNotifier(self.alert_channel)
self.notifier.stop = self.notifier_stop_mock = Mock()
self.notifier.start = self.notifier_start_mock = Mock()
def test_add_channel_adds_channel(self):
- """Channel in FirstHash with current loop is added to internal set."""
+ """Channel is added to `_silenced_channels` with the current loop."""
channel = Mock()
with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
self.notifier.add_channel(channel)
@@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
self.notifier_start_mock.assert_not_called()
def test_remove_channel_removes_channel(self):
- """Channel in FirstHash is removed from `_silenced_channels`."""
+ """Channel is removed from `_silenced_channels`."""
channel = Mock()
with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
self.notifier.remove_channel(channel)
@@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
with self.subTest(current_loop=current_loop):
with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
await self.notifier._notifier()
- self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ")
+ self.alert_channel.send.assert_called_once_with(
+ f"<@&{Roles.moderators}> currently silenced channels: "
+ )
self.alert_channel.send.reset_mock()
async def test_notifier_skips_alert(self):
@@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
self.alert_channel.send.assert_not_called()
-class SilenceTests(unittest.IsolatedAsyncioTestCase):
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the general functionality of the Silence cog."""
+
+ @autospec(silence, "Scheduler", pass_mocks=False)
def setUp(self) -> None:
self.bot = MockBot()
- self.cog = Silence(self.bot)
- self.ctx = MockContext()
- self.cog._verified_role = None
- # Set event so command callbacks can continue.
- self.cog._get_instance_vars_event.set()
+ self.cog = silence.Silence(self.bot)
- async def test_instance_vars_got_guild(self):
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_guild(self):
"""Bot got guild after it became available."""
- await self.cog._get_instance_vars()
- self.bot.wait_until_guild_available.assert_called_once()
+ await self.cog._async_init()
+ self.bot.wait_until_guild_available.assert_awaited_once()
self.bot.get_guild.assert_called_once_with(Guild.id)
- async def test_instance_vars_got_role(self):
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_role(self):
"""Got `Roles.verified` role from guild."""
- await self.cog._get_instance_vars()
guild = self.bot.get_guild()
- guild.get_role.assert_called_once_with(Roles.verified)
+ guild.get_role.side_effect = lambda id_: Mock(id=id_)
- async def test_instance_vars_got_channels(self):
+ await self.cog._async_init()
+ self.assertEqual(self.cog._verified_role.id, Roles.verified)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_channels(self):
"""Got channels from bot."""
- await self.cog._get_instance_vars()
- self.bot.get_channel.called_once_with(Channels.mod_alerts)
- self.bot.get_channel.called_once_with(Channels.mod_log)
+ self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
+
+ await self.cog._async_init()
+ self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)
- @mock.patch("bot.exts.moderation.silence.SilenceNotifier")
- async def test_instance_vars_got_notifier(self, notifier):
+ @autospec(silence, "SilenceNotifier")
+ async def test_async_init_got_notifier(self, notifier):
"""Notifier was started with channel."""
- mod_log = MockTextChannel()
- self.bot.get_channel.side_effect = (None, mod_log)
- await self.cog._get_instance_vars()
- notifier.assert_called_once_with(mod_log)
- self.bot.get_channel.side_effect = None
-
- async def test_silence_sent_correct_discord_message(self):
- """Check if proper message was sent when called with duration in channel with previous state."""
+ self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
+
+ await self.cog._async_init()
+ notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))
+ self.assertEqual(self.cog.notifier, notifier.return_value)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_rescheduled(self):
+ """`_reschedule_` coroutine was awaited."""
+ self.cog._reschedule = mock.create_autospec(self.cog._reschedule)
+ await self.cog._async_init()
+ self.cog._reschedule.assert_awaited_once_with()
+
+ def test_cog_unload_cancelled_tasks(self):
+ """The init task was cancelled."""
+ self.cog._init_task = asyncio.Future()
+ self.cog.cog_unload()
+
+ # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda.
+ self.assertTrue(self.cog._init_task.cancelled())
+
+ @autospec("discord.ext.commands", "has_any_role")
+ @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3))
+ async def test_cog_check(self, role_check):
+ """Role check was called with `MODERATION_ROLES`"""
+ ctx = MockContext()
+ role_check.return_value.predicate = mock.AsyncMock()
+
+ await self.cog.cog_check(ctx)
+ role_check.assert_called_once_with(*(1, 2, 3))
+ role_check.return_value.predicate.assert_awaited_once_with(ctx)
+
+
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class RescheduleTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the rescheduling of cached unsilences."""
+
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper)
+
+ with mock.patch.object(self.cog, "_reschedule", autospec=True):
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ async def test_skipped_missing_channel(self):
+ """Did nothing because the channel couldn't be retrieved."""
+ self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)]
+ self.bot.get_channel.return_value = None
+
+ await self.cog._reschedule()
+
+ self.cog.notifier.add_channel.assert_not_called()
+ self.cog._unsilence_wrapper.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ async def test_added_permanent_to_notifier(self):
+ """Permanently silenced channels were added to the notifier."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)]
+
+ await self.cog._reschedule()
+
+ self.cog.notifier.add_channel.assert_any_call(channels[0])
+ self.cog.notifier.add_channel.assert_any_call(channels[1])
+
+ self.cog._unsilence_wrapper.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ async def test_unsilenced_expired(self):
+ """Unsilenced expired silences."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)]
+
+ await self.cog._reschedule()
+
+ self.cog._unsilence_wrapper.assert_any_call(channels[0])
+ self.cog._unsilence_wrapper.assert_any_call(channels[1])
+
+ self.cog.notifier.add_channel.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ @mock.patch.object(silence, "datetime", new=PatchedDatetime)
+ async def test_rescheduled_active(self):
+ """Rescheduled active silences."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)]
+ silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc)
+
+ self.cog._unsilence_wrapper = mock.MagicMock()
+ unsilence_return = self.cog._unsilence_wrapper.return_value
+
+ await self.cog._reschedule()
+
+ # Yuck.
+ calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)]
+ self.cog.scheduler.schedule_later.assert_has_calls(calls)
+
+ unsilence_calls = [mock.call(channel) for channel in channels]
+ self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls)
+
+ self.cog.notifier.add_channel.assert_not_called()
+
+
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class SilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the silence command and its related helper methods."""
+
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ # Avoid unawaited coroutine warnings.
+ self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close()
+
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ self.channel = MockTextChannel()
+ self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False)
+ self.channel.overwrites_for.return_value = self.overwrite
+
+ async def test_sent_correct_message(self):
+ """Appropriate failure/success message was sent by the command."""
test_cases = (
- (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,),
- (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,),
- (5, f"{Emojis.cross_mark} current channel is already silenced.", False,),
+ (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,),
+ (None, silence.MSG_SILENCE_PERMANENT, True,),
+ (5, silence.MSG_SILENCE_FAIL, False,),
)
- for duration, result_message, _silence_patch_return in test_cases:
- with self.subTest(
- silence_duration=duration,
- result_message=result_message,
- starting_unsilenced_state=_silence_patch_return
- ):
- with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return):
- await self.cog.silence.callback(self.cog, self.ctx, duration)
- self.ctx.send.assert_called_once_with(result_message)
- self.ctx.reset_mock()
-
- async def test_unsilence_sent_correct_discord_message(self):
- """Check if proper message was sent when unsilencing channel."""
- test_cases = (
- (True, f"{Emojis.check_mark} unsilenced current channel."),
- (False, f"{Emojis.cross_mark} current channel was not silenced.")
+ for duration, message, was_silenced in test_cases:
+ ctx = MockContext()
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced):
+ with self.subTest(was_silenced=was_silenced, message=message, duration=duration):
+ await self.cog.silence.callback(self.cog, ctx, duration)
+ ctx.send.assert_called_once_with(message)
+
+ async def test_skipped_already_silenced(self):
+ """Permissions were not set and `False` was returned for an already silenced channel."""
+ subtests = (
+ (False, PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (True, PermissionOverwrite(send_messages=True, add_reactions=True)),
+ (True, PermissionOverwrite(send_messages=False, add_reactions=False)),
)
- for _unsilence_patch_return, result_message in test_cases:
- with self.subTest(
- starting_silenced_state=_unsilence_patch_return,
- result_message=result_message
- ):
- with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return):
- await self.cog.unsilence.callback(self.cog, self.ctx)
- self.ctx.send.assert_called_once_with(result_message)
- self.ctx.reset_mock()
-
- async def test_silence_private_for_false(self):
- """Permissions are not set and `False` is returned in an already silenced channel."""
- perm_overwrite = Mock(send_messages=False)
- channel = Mock(overwrites_for=Mock(return_value=perm_overwrite))
-
- self.assertFalse(await self.cog._silence(channel, True, None))
- channel.set_permissions.assert_not_called()
- async def test_silence_private_silenced_channel(self):
- """Channel had `send_message` permissions revoked."""
- channel = MockTextChannel()
- self.assertTrue(await self.cog._silence(channel, False, None))
- channel.set_permissions.assert_called_once()
- self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages'])
+ for contains, overwrite in subtests:
+ with self.subTest(contains=contains, overwrite=overwrite):
+ self.cog.scheduler.__contains__.return_value = contains
+ channel = MockTextChannel()
+ channel.overwrites_for.return_value = overwrite
+
+ self.assertFalse(await self.cog._set_silence_overwrites(channel))
+ channel.set_permissions.assert_not_called()
+
+ async def test_silenced_channel(self):
+ """Channel had `send_message` and `add_reactions` permissions revoked for verified role."""
+ self.assertTrue(await self.cog._set_silence_overwrites(self.channel))
+ self.assertFalse(self.overwrite.send_messages)
+ self.assertFalse(self.overwrite.add_reactions)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite
+ )
- async def test_silence_private_preserves_permissions(self):
- """Previous permissions were preserved when channel was silenced."""
- channel = MockTextChannel()
- # Set up mock channel permission state.
- mock_permissions = PermissionOverwrite()
- mock_permissions_dict = dict(mock_permissions)
- channel.overwrites_for.return_value = mock_permissions
- await self.cog._silence(channel, False, None)
- new_permissions = channel.set_permissions.call_args.kwargs
- # Remove 'send_messages' key because it got changed in the method.
- del new_permissions['send_messages']
- del mock_permissions_dict['send_messages']
- self.assertDictEqual(mock_permissions_dict, new_permissions)
-
- async def test_silence_private_notifier(self):
- """Channel should be added to notifier with `persistent` set to `True`, and the other way around."""
- channel = MockTextChannel()
- with mock.patch.object(self.cog, "notifier", create=True):
- with self.subTest(persistent=True):
- await self.cog._silence(channel, True, None)
- self.cog.notifier.add_channel.assert_called_once()
-
- with mock.patch.object(self.cog, "notifier", create=True):
- with self.subTest(persistent=False):
- await self.cog._silence(channel, False, None)
- self.cog.notifier.add_channel.assert_not_called()
-
- async def test_silence_private_added_muted_channel(self):
- """Channel was added to `muted_channels` on silence."""
+ async def test_preserved_other_overwrites(self):
+ """Channel's other unrelated overwrites were not changed."""
+ prev_overwrite_dict = dict(self.overwrite)
+ await self.cog._set_silence_overwrites(self.channel)
+ new_overwrite_dict = dict(self.overwrite)
+
+ # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.
+ del prev_overwrite_dict['send_messages']
+ del prev_overwrite_dict['add_reactions']
+ del new_overwrite_dict['send_messages']
+ del new_overwrite_dict['add_reactions']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+
+ async def test_temp_not_added_to_notifier(self):
+ """Channel was not added to notifier if a duration was set for the silence."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ await self.cog.silence.callback(self.cog, MockContext(), 15)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_indefinite_added_to_notifier(self):
+ """Channel was added to notifier if a duration was not set for the silence."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ await self.cog.silence.callback(self.cog, MockContext(), None)
+ self.cog.notifier.add_channel.assert_called_once()
+
+ async def test_silenced_not_added_to_notifier(self):
+ """Channel was not added to the notifier if it was already silenced."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False):
+ await self.cog.silence.callback(self.cog, MockContext(), 15)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_cached_previous_overwrites(self):
+ """Channel's previous overwrites were cached."""
+ overwrite_json = '{"send_messages": true, "add_reactions": false}'
+ await self.cog._set_silence_overwrites(self.channel)
+ self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json)
+
+ @autospec(silence, "datetime")
+ async def test_cached_unsilence_time(self, datetime_mock):
+ """The UTC POSIX timestamp for the unsilence was cached."""
+ now_timestamp = 100
+ duration = 15
+ timestamp = now_timestamp + duration * 60
+ datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc)
+
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, duration)
+
+ self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp)
+ datetime_mock.now.assert_called_once_with(tz=timezone.utc) # Ensure it's using an aware dt.
+
+ async def test_cached_indefinite_time(self):
+ """A value of -1 was cached for a permanent silence."""
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, None)
+ self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)
+
+ async def test_scheduled_task(self):
+ """An unsilence task was scheduled."""
+ ctx = MockContext(channel=self.channel, invoke=mock.MagicMock())
+
+ await self.cog.silence.callback(self.cog, ctx, 5)
+
+ args = (300, ctx.channel.id, ctx.invoke.return_value)
+ self.cog.scheduler.schedule_later.assert_called_once_with(*args)
+ ctx.invoke.assert_called_once_with(self.cog.unsilence)
+
+ async def test_permanent_not_scheduled(self):
+ """A task was not scheduled for a permanent silence."""
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, None)
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+
+@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)
+class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the unsilence command and its related helper methods."""
+
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot(get_channel=lambda _: MockTextChannel())
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
+ self.cog.previous_overwrites = overwrites_cache
+
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ self.cog.scheduler.__contains__.return_value = True
+ overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
+ self.channel = MockTextChannel()
+ self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False)
+ self.channel.overwrites_for.return_value = self.overwrite
+
+ async def test_sent_correct_message(self):
+ """Appropriate failure/success message was sent by the command."""
+ unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True)
+ test_cases = (
+ (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite),
+ (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)),
+ (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)),
+ )
+ for was_unsilenced, message, overwrite in test_cases:
+ ctx = MockContext()
+ with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite):
+ with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced):
+ ctx.channel.overwrites_for.return_value = overwrite
+ await self.cog.unsilence.callback(self.cog, ctx)
+ ctx.channel.send.assert_called_once_with(message)
+
+ async def test_skipped_already_unsilenced(self):
+ """Permissions were not set and `False` was returned for an already unsilenced channel."""
+ self.cog.scheduler.__contains__.return_value = False
+ self.cog.previous_overwrites.get.return_value = None
channel = MockTextChannel()
- with mock.patch.object(self.cog, "muted_channels") as muted_channels:
- await self.cog._silence(channel, False, None)
- muted_channels.add.assert_called_once_with(channel)
- async def test_unsilence_private_for_false(self):
- """Permissions are not set and `False` is returned in an unsilenced channel."""
- channel = Mock()
self.assertFalse(await self.cog._unsilence(channel))
channel.set_permissions.assert_not_called()
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_unsilenced_channel(self, _):
- """Channel had `send_message` permissions restored"""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- self.assertTrue(await self.cog._unsilence(channel))
- channel.set_permissions.assert_called_once()
- self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages'])
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_removed_notifier(self, notifier):
- """Channel was removed from `notifier` on unsilence."""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- await self.cog._unsilence(channel)
- notifier.remove_channel.assert_called_once_with(channel)
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_removed_muted_channel(self, _):
- """Channel was removed from `muted_channels` on unsilence."""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- with mock.patch.object(self.cog, "muted_channels") as muted_channels:
- await self.cog._unsilence(channel)
- muted_channels.discard.assert_called_once_with(channel)
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_preserves_permissions(self, _):
- """Previous permissions were preserved when channel was unsilenced."""
- channel = MockTextChannel()
- # Set up mock channel permission state.
- mock_permissions = PermissionOverwrite(send_messages=False)
- mock_permissions_dict = dict(mock_permissions)
- channel.overwrites_for.return_value = mock_permissions
- await self.cog._unsilence(channel)
- new_permissions = channel.set_permissions.call_args.kwargs
- # Remove 'send_messages' key because it got changed in the method.
- del new_permissions['send_messages']
- del mock_permissions_dict['send_messages']
- self.assertDictEqual(mock_permissions_dict, new_permissions)
-
- @mock.patch("bot.exts.moderation.silence.asyncio")
- @mock.patch.object(Silence, "_mod_alerts_channel", create=True)
- def test_cog_unload_starts_task(self, alert_channel, asyncio_mock):
- """Task for sending an alert was created with present `muted_channels`."""
- with mock.patch.object(self.cog, "muted_channels"):
- self.cog.cog_unload()
- alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ")
- asyncio_mock.create_task.assert_called_once_with(alert_channel.send())
-
- @mock.patch("bot.exts.moderation.silence.asyncio")
- def test_cog_unload_skips_task_start(self, asyncio_mock):
- """No task created with no channels."""
- self.cog.cog_unload()
- asyncio_mock.create_task.assert_not_called()
+ async def test_restored_overwrites(self):
+ """Channel's `send_message` and `add_reactions` overwrites were restored."""
+ await self.cog._unsilence(self.channel)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite,
+ )
- @mock.patch("discord.ext.commands.has_any_role")
- @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))
- async def test_cog_check(self, role_check):
- """Role check is called with `MODERATION_ROLES`"""
- role_check.return_value.predicate = mock.AsyncMock()
- await self.cog.cog_check(self.ctx)
- role_check.assert_called_once_with(*(1, 2, 3))
- role_check.return_value.predicate.assert_awaited_once_with(self.ctx)
+ # Recall that these values are determined by the fixture.
+ self.assertTrue(self.overwrite.send_messages)
+ self.assertFalse(self.overwrite.add_reactions)
+
+ async def test_cache_miss_used_default_overwrites(self):
+ """Both overwrites were set to None due previous values not being found in the cache."""
+ self.cog.previous_overwrites.get.return_value = None
+
+ await self.cog._unsilence(self.channel)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite,
+ )
+
+ self.assertIsNone(self.overwrite.send_messages)
+ self.assertIsNone(self.overwrite.add_reactions)
+
+ async def test_cache_miss_sent_mod_alert(self):
+ """A message was sent to the mod alerts channel."""
+ self.cog.previous_overwrites.get.return_value = None
+
+ await self.cog._unsilence(self.channel)
+ self.cog._mod_alerts_channel.send.assert_awaited_once()
+
+ async def test_removed_notifier(self):
+ """Channel was removed from `notifier`."""
+ await self.cog._unsilence(self.channel)
+ self.cog.notifier.remove_channel.assert_called_once_with(self.channel)
+
+ async def test_deleted_cached_overwrite(self):
+ """Channel was deleted from the overwrites cache."""
+ await self.cog._unsilence(self.channel)
+ self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id)
+
+ async def test_deleted_cached_time(self):
+ """Channel was deleted from the timestamp cache."""
+ await self.cog._unsilence(self.channel)
+ self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id)
+
+ async def test_cancelled_task(self):
+ """The scheduled unsilence task should be cancelled."""
+ await self.cog._unsilence(self.channel)
+ self.cog.scheduler.cancel.assert_called_once_with(self.channel.id)
+
+ async def test_preserved_other_overwrites(self):
+ """Channel's other unrelated overwrites were not changed, including cache misses."""
+ for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):
+ with self.subTest(overwrite_json=overwrite_json):
+ self.cog.previous_overwrites.get.return_value = overwrite_json
+
+ prev_overwrite_dict = dict(self.overwrite)
+ await self.cog._unsilence(self.channel)
+ new_overwrite_dict = dict(self.overwrite)
+
+ # Remove these keys because they were modified by the unsilence.
+ del prev_overwrite_dict['send_messages']
+ del prev_overwrite_dict['add_reactions']
+ del new_overwrite_dict['send_messages']
+ del new_overwrite_dict['add_reactions']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 40b2202aa..9a42d0610 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -52,6 +52,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),
('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'),
+ ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'),
+ ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```',
+ 'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'),
+ ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```',
+ 'print("How\'s it going?")', 'code block preceded by inline code'),
+ ('`print("Hello world!")`\ntext\n`print("Hello world!")`',
+ 'print("Hello world!")', 'one inline code block of two')
)
for case, expected, testname in cases:
with self.subTest(msg=f'Extract code from {testname}.'):
@@ -154,7 +161,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.send_eval = AsyncMock(return_value=response)
self.cog.continue_eval = AsyncMock(return_value=None)
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
self.cog.continue_eval.assert_called_once_with(ctx, response)
@@ -168,7 +175,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.continue_eval = AsyncMock()
self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None)
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
self.cog.continue_eval.assert_called_with(ctx, response)
@@ -180,7 +187,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.author.mention = '@LemonLemonishBeard#0042'
ctx.send = AsyncMock()
self.cog.jobs = (42,)
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ await self.cog.eval_command(self.cog, ctx=ctx, code='MyAwesomeCode')
ctx.send.assert_called_once_with(
"@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
)
@@ -188,8 +195,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_eval_command_call_help(self):
"""Test if the eval command call the help command if no code is provided."""
ctx = MockContext(command="sentinel")
- await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
- ctx.send_help.assert_called_once_with("sentinel")
+ await self.cog.eval_command(self.cog, ctx=ctx, code='')
+ ctx.send_help.assert_called_once_with(ctx.command)
async def test_send_eval(self):
"""Test the send_eval function."""
@@ -290,7 +297,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
)
)
ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
- ctx.message.clear_reactions.assert_called_once()
+ ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
response.delete.assert_called_once()
async def test_continue_eval_does_not_continue(self):
@@ -299,7 +306,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
actual = await self.cog.continue_eval(ctx, MockMessage())
self.assertEqual(actual, None)
- ctx.message.clear_reactions.assert_called_once()
+ ctx.message.clear_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
async def test_get_code(self):
"""Should return 1st arg (or None) if eval cmd in message, otherwise return full content."""
diff --git a/tests/bot/patches/__init__.py b/tests/bot/patches/__init__.py
deleted file mode 100644
index e69de29bb..000000000
--- a/tests/bot/patches/__init__.py
+++ /dev/null
diff --git a/tests/helpers.py b/tests/helpers.py
index e47fdf28f..870f66197 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -5,7 +5,7 @@ import itertools
import logging
import unittest.mock
from asyncio import AbstractEventLoop
-from typing import Callable, Iterable, Optional
+from typing import Iterable, Optional
import discord
from aiohttp import ClientSession
@@ -14,6 +14,7 @@ from discord.ext.commands import Context
from bot.api import APIClient
from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
+from tests._autospec import autospec # noqa: F401 other modules import it via this module
for logger in logging.Logger.manager.loggerDict.values():
@@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def autospec(target, *attributes: str, **kwargs) -> Callable:
- """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True."""
- # Caller's kwargs should take priority and overwrite the defaults.
- kwargs = {'spec_set': True, 'autospec': True, **kwargs}
-
- # Import the target if it's a string.
- # This is to support both object and string targets like patch.multiple.
- if type(target) is str:
- target = unittest.mock._importer(target)
-
- def decorator(func):
- for attribute in attributes:
- patcher = unittest.mock.patch.object(target, attribute, **kwargs)
- func = patcher(func)
- return func
- return decorator
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.