aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2022-02-22 18:17:32 -0800
committerGravatar MarkKoz <[email protected]>2022-02-22 18:17:32 -0800
commit64ca00f0d306ab9d8bfd78b9af9e99d33283cb03 (patch)
treeaf7231269aca52153f7f19df8a2d9b96bb1fb6b5 /tests
parentClarify that a resent infraction DM is not a new infraction. (diff)
parentMerge pull request #2078 from python-discord/chris/fix/help-channel-errors (diff)
Merge main into feat/mod/1664/resend-infraction
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py3
-rw-r--r--tests/base.py5
-rw-r--r--tests/bot/exts/backend/sync/test_base.py1
-rw-r--r--tests/bot/exts/backend/sync/test_cog.py7
-rw-r--r--tests/bot/exts/backend/sync/test_users.py7
-rw-r--r--tests/bot/exts/backend/test_error_handler.py64
-rw-r--r--tests/bot/exts/events/__init__.py0
-rw-r--r--tests/bot/exts/events/test_code_jams.py (renamed from tests/bot/exts/utils/test_jams.py)68
-rw-r--r--tests/bot/exts/filters/test_filtering.py40
-rw-r--r--tests/bot/exts/filters/test_token_remover.py17
-rw-r--r--tests/bot/exts/info/test_information.py104
-rw-r--r--tests/bot/exts/moderation/infraction/test_infractions.py189
-rw-r--r--tests/bot/exts/moderation/infraction/test_utils.py64
-rw-r--r--tests/bot/exts/moderation/test_clean.py104
-rw-r--r--tests/bot/exts/moderation/test_incidents.py109
-rw-r--r--tests/bot/exts/moderation/test_modlog.py2
-rw-r--r--tests/bot/exts/moderation/test_silence.py73
-rw-r--r--tests/bot/exts/utils/test_snekbox.py24
-rw-r--r--tests/bot/rules/test_mentions.py26
-rw-r--r--tests/bot/test_converters.py111
-rw-r--r--tests/bot/utils/test_checks.py1
-rw-r--r--tests/bot/utils/test_message_cache.py214
-rw-r--r--tests/bot/utils/test_time.py121
-rw-r--r--tests/helpers.py34
-rw-r--r--tests/test_base.py11
25 files changed, 972 insertions, 427 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index 2228110ad..c2b9d12dc 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,5 +1,6 @@
import logging
+from bot.log import get_logger
-log = logging.getLogger()
+log = get_logger()
log.setLevel(logging.CRITICAL)
diff --git a/tests/base.py b/tests/base.py
index d99b9ac31..5e304ea9d 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -6,6 +6,7 @@ from typing import Dict
import discord
from discord.ext import commands
+from bot.log import get_logger
from tests import helpers
@@ -42,7 +43,7 @@ class LoggingTestsMixin:
manager when we're testing under the assumption that no log records will be emitted.
"""
if not isinstance(logger, logging.Logger):
- logger = logging.getLogger(logger)
+ logger = get_logger(logger)
if level:
level = logging._nameToLevel.get(level, level)
@@ -102,4 +103,4 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(commands.MissingPermissions) as cm:
await cmd.can_run(ctx)
- self.assertCountEqual(permissions.keys(), cm.exception.missing_perms)
+ self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions)
diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py
index 3ad9db9c3..9dc46005b 100644
--- a/tests/bot/exts/backend/sync/test_base.py
+++ b/tests/bot/exts/backend/sync/test_base.py
@@ -1,7 +1,6 @@
import unittest
from unittest import mock
-
from bot.api import ResponseCodeError
from bot.exts.backend.sync._syncers import Syncer
from tests import helpers
diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py
index 22a07313e..fdd0ab74a 100644
--- a/tests/bot/exts/backend/sync/test_cog.py
+++ b/tests/bot/exts/backend/sync/test_cog.py
@@ -60,13 +60,13 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):
class SyncCogTests(SyncCogTestCase):
"""Tests for the Sync cog."""
+ @mock.patch("bot.utils.scheduling.create_task")
@mock.patch.object(Sync, "sync_guild", new_callable=mock.MagicMock)
- def test_sync_cog_init(self, sync_guild):
+ def test_sync_cog_init(self, sync_guild, create_task):
"""Should instantiate syncers and run a sync for the guild."""
# Reset because a Sync cog was already instantiated in setUp.
self.RoleSyncer.reset_mock()
self.UserSyncer.reset_mock()
- self.bot.loop.create_task = mock.MagicMock()
mock_sync_guild_coro = mock.MagicMock()
sync_guild.return_value = mock_sync_guild_coro
@@ -74,7 +74,8 @@ class SyncCogTests(SyncCogTestCase):
Sync(self.bot)
sync_guild.assert_called_once_with()
- self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
+ create_task.assert_called_once()
+ self.assertEqual(create_task.call_args.args[0], mock_sync_guild_coro)
async def test_sync_cog_sync_guild(self):
"""Roles and users should be synced only if a guild is successfully retrieved."""
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index 27932be95..2fc97af2d 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,6 +1,8 @@
import unittest
from unittest import mock
+from discord.errors import NotFound
+
from bot.exts.backend.sync._syncers import UserSyncer, _Diff
from tests import helpers
@@ -10,7 +12,7 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("roles", [666])
+ kwargs.setdefault("roles", [helpers.MockRole(id=666)])
kwargs.setdefault("in_guild", True)
return kwargs
@@ -134,6 +136,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.get_mock_member(fake_user()),
None
]
+ guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found")
actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [{"id": 63, "in_guild": False}], None)
@@ -158,6 +161,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.get_mock_member(updated_user),
None
]
+ guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found")
actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)
@@ -177,6 +181,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.get_mock_member(fake_user()),
None
]
+ guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found")
actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [], None)
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 2b0549b98..35fa0ee59 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -107,7 +107,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
"""Should send error with `ctx.send` when error is `CommandOnCooldown`."""
self.ctx.reset_mock()
cog = ErrorHandler(self.bot)
- error = errors.CommandOnCooldown(10, 9)
+ error = errors.CommandOnCooldown(10, 9, type=None)
self.assertIsNone(await cog.on_command_error(self.ctx, error))
self.ctx.send.assert_awaited_once_with(error)
@@ -337,14 +337,12 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
async def test_try_get_tag_get_command(self):
"""Should call `Bot.get_command` with `tags get` argument."""
self.bot.get_command.reset_mock()
- self.ctx.invoked_with = "foo"
await self.cog.try_get_tag(self.ctx)
self.bot.get_command.assert_called_once_with("tags get")
async def test_try_get_tag_invoked_from_error_handler(self):
"""`self.ctx` should have `invoked_from_error_handler` `True`."""
self.ctx.invoked_from_error_handler = False
- self.ctx.invoked_with = "foo"
await self.cog.try_get_tag(self.ctx)
self.assertTrue(self.ctx.invoked_from_error_handler)
@@ -359,38 +357,12 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
err = errors.CommandError()
self.tag.get_command.can_run = AsyncMock(side_effect=err)
self.cog.on_command_error = AsyncMock()
- self.ctx.invoked_with = "foo"
self.assertIsNone(await self.cog.try_get_tag(self.ctx))
self.cog.on_command_error.assert_awaited_once_with(self.ctx, err)
- @patch("bot.exts.backend.error_handler.TagNameConverter")
- async def test_try_get_tag_convert_success(self, tag_converter):
- """Converting tag should successful."""
- self.ctx.invoked_with = "foo"
- tag_converter.convert = AsyncMock(return_value="foo")
- self.assertIsNone(await self.cog.try_get_tag(self.ctx))
- tag_converter.convert.assert_awaited_once_with(self.ctx, "foo")
- self.ctx.invoke.assert_awaited_once()
-
- @patch("bot.exts.backend.error_handler.TagNameConverter")
- async def test_try_get_tag_convert_fail(self, tag_converter):
- """Converting tag should raise `BadArgument`."""
- self.ctx.reset_mock()
- self.ctx.invoked_with = "bar"
- tag_converter.convert = AsyncMock(side_effect=errors.BadArgument())
- self.assertIsNone(await self.cog.try_get_tag(self.ctx))
- self.ctx.invoke.assert_not_awaited()
-
- async def test_try_get_tag_ctx_invoke(self):
- """Should call `ctx.invoke` with proper args/kwargs."""
- self.ctx.reset_mock()
- self.ctx.invoked_with = "foo"
- self.assertIsNone(await self.cog.try_get_tag(self.ctx))
- self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo")
-
async def test_dont_call_suggestion_tag_sent(self):
"""Should never call command suggestion if tag is already sent."""
- self.ctx.invoked_with = "foo"
+ self.ctx.message = MagicMock(content="foo")
self.ctx.invoke = AsyncMock(return_value=True)
self.cog.send_command_suggestion = AsyncMock()
@@ -572,38 +544,6 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
push_scope_mock.set_extra.has_calls(set_extra_calls)
-class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
- """Other `ErrorHandler` tests."""
-
- def setUp(self):
- self.bot = MockBot()
- self.ctx = MockContext()
-
- async def test_get_help_command_command_specified(self):
- """Should return coroutine of help command of specified command."""
- self.ctx.command = "foo"
- result = ErrorHandler.get_help_command(self.ctx)
- expected = self.ctx.send_help("foo")
- self.assertEqual(result.__qualname__, expected.__qualname__)
- self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
-
- # Await coroutines to avoid warnings
- await result
- await expected
-
- async def test_get_help_command_no_command_specified(self):
- """Should return coroutine of help command."""
- self.ctx.command = None
- result = ErrorHandler.get_help_command(self.ctx)
- expected = self.ctx.send_help()
- self.assertEqual(result.__qualname__, expected.__qualname__)
- self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
-
- # Await coroutines to avoid warnings
- await result
- await expected
-
-
class ErrorHandlerSetupTests(unittest.TestCase):
"""Tests for `ErrorHandler` `setup` function."""
diff --git a/tests/bot/exts/events/__init__.py b/tests/bot/exts/events/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/exts/events/__init__.py
diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/events/test_code_jams.py
index 368a15476..0856546af 100644
--- a/tests/bot/exts/utils/test_jams.py
+++ b/tests/bot/exts/events/test_code_jams.py
@@ -1,14 +1,15 @@
import unittest
-from unittest.mock import AsyncMock, MagicMock, create_autospec
+from unittest.mock import AsyncMock, MagicMock, create_autospec, patch
from discord import CategoryChannel
from discord.ext.commands import BadArgument
from bot.constants import Roles
-from bot.exts.utils import jams
+from bot.exts.events import code_jams
+from bot.exts.events.code_jams import _channels, _cog
from tests.helpers import (
- MockAttachment, MockBot, MockCategoryChannel, MockContext,
- MockGuild, MockMember, MockRole, MockTextChannel
+ MockAttachment, MockBot, MockCategoryChannel, MockContext, MockGuild, MockMember, MockRole, MockTextChannel,
+ autospec
)
TEST_CSV = b"""\
@@ -40,7 +41,7 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
self.command_user = MockMember([self.admin_role])
self.guild = MockGuild([self.admin_role])
self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild)
- self.cog = jams.CodeJams(self.bot)
+ self.cog = _cog.CodeJams(self.bot)
async def test_message_without_attachments(self):
"""If no link or attachments are provided, commands.BadArgument should be raised."""
@@ -49,7 +50,9 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(BadArgument):
await self.cog.create(self.cog, self.ctx, None)
- async def test_result_sending(self):
+ @patch.object(_channels, "create_team_channel")
+ @patch.object(_channels, "create_team_leader_channel")
+ async def test_result_sending(self, create_leader_channel, create_team_channel):
"""Should call `ctx.send` when everything goes right."""
self.ctx.message.attachments = [MockAttachment()]
self.ctx.message.attachments[0].read = AsyncMock()
@@ -61,14 +64,12 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
self.ctx.guild.create_role = AsyncMock()
self.ctx.guild.create_role.return_value = team_leaders
- self.cog.create_team_channel = AsyncMock()
- self.cog.create_team_leader_channel = AsyncMock()
self.cog.add_roles = AsyncMock()
await self.cog.create(self.cog, self.ctx, None)
- self.cog.create_team_channel.assert_awaited()
- self.cog.create_team_leader_channel.assert_awaited_once_with(
+ create_team_channel.assert_awaited()
+ create_leader_channel.assert_awaited_once_with(
self.ctx.guild, team_leaders
)
self.ctx.send.assert_awaited_once()
@@ -81,25 +82,24 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
self.ctx.send.assert_awaited_once()
- async def test_category_doesnt_exist(self):
+ @patch.object(_channels, "_send_status_update")
+ async def test_category_doesnt_exist(self, update):
"""Should create a new code jam category."""
subtests = (
[],
- [get_mock_category(jams.MAX_CHANNELS, jams.CATEGORY_NAME)],
- [get_mock_category(jams.MAX_CHANNELS - 2, "other")],
+ [get_mock_category(_channels.MAX_CHANNELS, _channels.CATEGORY_NAME)],
+ [get_mock_category(_channels.MAX_CHANNELS - 2, "other")],
)
- self.cog.send_status_update = AsyncMock()
-
for categories in subtests:
- self.cog.send_status_update.reset_mock()
+ update.reset_mock()
self.guild.reset_mock()
self.guild.categories = categories
with self.subTest(categories=categories):
- actual_category = await self.cog.get_category(self.guild)
+ actual_category = await _channels._get_category(self.guild)
- self.cog.send_status_update.assert_called_once()
+ update.assert_called_once()
self.guild.create_category_channel.assert_awaited_once()
category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"]
@@ -109,45 +109,41 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
async def test_category_channel_exist(self):
"""Should not try to create category channel."""
- expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME)
+ expected_category = get_mock_category(_channels.MAX_CHANNELS - 2, _channels.CATEGORY_NAME)
self.guild.categories = [
- get_mock_category(jams.MAX_CHANNELS - 2, "other"),
+ get_mock_category(_channels.MAX_CHANNELS - 2, "other"),
expected_category,
- get_mock_category(0, jams.CATEGORY_NAME),
+ get_mock_category(0, _channels.CATEGORY_NAME),
]
- actual_category = await self.cog.get_category(self.guild)
+ actual_category = await _channels._get_category(self.guild)
self.assertEqual(expected_category, actual_category)
async def test_channel_overwrites(self):
"""Should have correct permission overwrites for users and roles."""
leader = (MockMember(), True)
members = [leader] + [(MockMember(), False) for _ in range(4)]
- overwrites = self.cog.get_overwrites(members, self.guild)
+ overwrites = _channels._get_overwrites(members, self.guild)
for member, _ in members:
self.assertTrue(overwrites[member].read_messages)
- async def test_team_channels_creation(self):
+ @patch.object(_channels, "_get_overwrites")
+ @patch.object(_channels, "_get_category")
+ @autospec(_channels, "_add_team_leader_roles", pass_mocks=False)
+ async def test_team_channels_creation(self, get_category, get_overwrites):
"""Should create a text channel for a team."""
team_leaders = MockRole()
members = [(MockMember(), True)] + [(MockMember(), False) for _ in range(5)]
category = MockCategoryChannel()
category.create_text_channel = AsyncMock()
- self.cog.get_overwrites = MagicMock()
- self.cog.get_category = AsyncMock()
- self.cog.get_category.return_value = category
- self.cog.add_team_leader_roles = AsyncMock()
-
- await self.cog.create_team_channel(self.guild, "my-team", members, team_leaders)
- self.cog.add_team_leader_roles.assert_awaited_once_with(members, team_leaders)
- self.cog.get_overwrites.assert_called_once_with(members, self.guild)
- self.cog.get_category.assert_awaited_once_with(self.guild)
+ get_category.return_value = category
+ await _channels.create_team_channel(self.guild, "my-team", members, team_leaders)
category.create_text_channel.assert_awaited_once_with(
"my-team",
- overwrites=self.cog.get_overwrites.return_value
+ overwrites=get_overwrites.return_value
)
async def test_jam_roles_adding(self):
@@ -156,7 +152,7 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
leader = MockMember()
members = [(leader, True)] + [(MockMember(), False) for _ in range(4)]
- await self.cog.add_team_leader_roles(members, leader_role)
+ await _channels._add_team_leader_roles(members, leader_role)
leader.add_roles.assert_awaited_once_with(leader_role)
for member, is_leader in members:
@@ -170,5 +166,5 @@ class CodeJamSetup(unittest.TestCase):
def test_setup(self):
"""Should call `bot.add_cog`."""
bot = MockBot()
- jams.setup(bot)
+ code_jams.setup(bot)
bot.add_cog.assert_called_once()
diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py
new file mode 100644
index 000000000..8ae59c1f1
--- /dev/null
+++ b/tests/bot/exts/filters/test_filtering.py
@@ -0,0 +1,40 @@
+import unittest
+from unittest.mock import patch
+
+from bot.exts.filters import filtering
+from tests.helpers import MockBot, autospec
+
+
+class FilteringCogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests the `Filtering` cog."""
+
+ def setUp(self):
+ """Instantiate the bot and cog."""
+ self.bot = MockBot()
+ with patch("bot.utils.scheduling.create_task", new=lambda task, **_: task.close()):
+ self.cog = filtering.Filtering(self.bot)
+
+ @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"])
+ async def test_token_filter(self):
+ """Ensure that a filter token is correctly detected in a message."""
+ messages = {
+ "": False,
+ "no matches": False,
+ "TOKEN": True,
+
+ # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6
+ "https://google.com TOKEN": True,
+ "https://google.com something else": False,
+ }
+
+ for message, match in messages.items():
+ with self.subTest(input=message, match=match):
+ result, _ = await self.cog._has_watch_regex_match(message)
+
+ self.assertEqual(
+ match,
+ bool(result),
+ msg=f"Hit was {'expected' if match else 'not expected'} for this input."
+ )
+ if result:
+ self.assertEqual("TOKEN", result.group())
diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py
index 51feae9cb..4db27269a 100644
--- a/tests/bot/exts/filters/test_token_remover.py
+++ b/tests/bot/exts/filters/test_token_remover.py
@@ -26,7 +26,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
self.msg.guild.get_member.return_value.bot = False
self.msg.guild.get_member.return_value.__str__.return_value = "Woody"
self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
- self.msg.author.avatar_url_as.return_value = "picture-lemon.png"
+ self.msg.author.display_avatar.url = "picture-lemon.png"
def test_extract_user_id_valid(self):
"""Should consider user IDs valid if they decode into an integer ID."""
@@ -295,20 +295,21 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
)
@autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE")
- def test_format_userid_log_message_unknown(self, unknown_user_log_message):
+ async def test_format_userid_log_message_unknown(self, unknown_user_log_message,):
"""Should correctly format the user ID portion when the actual user it belongs to is unknown."""
token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
unknown_user_log_message.format.return_value = " Partner"
msg = MockMessage(id=555, content="hello world")
msg.guild.get_member.return_value = None
+ msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found")
- return_value = TokenRemover.format_userid_log_message(msg, token)
+ return_value = await TokenRemover.format_userid_log_message(msg, token)
self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False))
unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332)
@autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
- def test_format_userid_log_message_bot(self, known_user_log_message):
+ async def test_format_userid_log_message_bot(self, known_user_log_message):
"""Should correctly format the user ID portion when the ID belongs to a known bot."""
token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
known_user_log_message.format.return_value = " Partner"
@@ -316,7 +317,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
msg.guild.get_member.return_value.__str__.return_value = "Sam"
msg.guild.get_member.return_value.bot = True
- return_value = TokenRemover.format_userid_log_message(msg, token)
+ return_value = await TokenRemover.format_userid_log_message(msg, token)
self.assertEqual(return_value, (known_user_log_message.format.return_value, True))
@@ -327,12 +328,12 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
)
@autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
- def test_format_log_message_user_token_user(self, user_token_message):
+ async def test_format_log_message_user_token_user(self, user_token_message):
"""Should correctly format the user ID portion when the ID belongs to a known user."""
token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
user_token_message.format.return_value = "Partner"
- return_value = TokenRemover.format_userid_log_message(self.msg, token)
+ return_value = await TokenRemover.format_userid_log_message(self.msg, token)
self.assertEqual(return_value, (user_token_message.format.return_value, True))
user_token_message.format.assert_called_once_with(
@@ -375,7 +376,7 @@ class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
colour=Colour(constants.Colours.soft_red),
title="Token removed!",
text=log_msg + "\n" + userid_log_message,
- thumbnail=self.msg.author.avatar_url_as.return_value,
+ thumbnail=self.msg.author.display_avatar.url,
channel_id=constants.Channels.mod_alerts,
ping_everyone=True,
)
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index 770660fe3..d896b7652 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -1,6 +1,7 @@
import textwrap
import unittest
import unittest.mock
+from datetime import datetime
import discord
@@ -42,7 +43,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase):
embed = kwargs.pop('embed')
self.assertEqual(embed.title, "Role information (Total 1 role)")
- self.assertEqual(embed.colour, discord.Colour.blurple())
+ self.assertEqual(embed.colour, discord.Colour.og_blurple())
self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
async def test_role_info_command(self):
@@ -50,7 +51,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase):
dummy_role = helpers.MockRole(
name="Dummy",
id=112233445566778899,
- colour=discord.Colour.blurple(),
+ colour=discord.Colour.og_blurple(),
position=10,
members=[self.ctx.author],
permissions=discord.Permissions(0)
@@ -80,7 +81,7 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase):
admin_embed = admin_kwargs["embed"]
self.assertEqual(dummy_embed.title, "Dummy info")
- self.assertEqual(dummy_embed.colour, discord.Colour.blurple())
+ self.assertEqual(dummy_embed.colour, discord.Colour.og_blurple())
self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id))
self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}")
@@ -262,7 +263,6 @@ class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):
await self._method_subtests(self.cog.user_nomination_counts, test_values, header)
[email protected]("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago"))
@unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50])
class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the creation of the `!user` embed."""
@@ -277,6 +277,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
async def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -285,8 +289,9 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
user.nick = None
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
user.colour = 0
+ user.created_at = user.joined_at = datetime.utcnow()
- embed = await self.cog.create_user_embed(ctx, user)
+ embed = await self.cog.create_user_embed(ctx, user, False)
self.assertEqual(embed.title, "Mr. Hemlock")
@@ -294,6 +299,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
async def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -302,8 +311,9 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
user.nick = "Cat lover"
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
user.colour = 0
+ user.created_at = user.joined_at = datetime.utcnow()
- embed = await self.cog.create_user_embed(ctx, user)
+ embed = await self.cog.create_user_embed(ctx, user, False)
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
@@ -311,6 +321,10 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
async def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -318,14 +332,19 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
# A `MockMember` has the @Everyone role by default; we add the Admins to that.
user = helpers.MockMember(roles=[admins_role], colour=100)
+ user.created_at = user.joined_at = datetime.utcnow()
- embed = await self.cog.create_user_embed(ctx, user)
+ embed = await self.cog.create_user_embed(ctx, user, False)
self.assertIn("&Admins", embed.fields[1].value)
self.assertNotIn("&Everyone", embed.fields[1].value)
@unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
@unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
async def test_create_user_embed_expanded_information_in_moderation_channels(
self,
nomination_counts,
@@ -340,14 +359,15 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
nomination_counts.return_value = ("Nominations", "nomination info")
user = helpers.MockMember(id=314, roles=[moderators_role], colour=100)
- embed = await self.cog.create_user_embed(ctx, user)
+ user.created_at = user.joined_at = datetime.utcfromtimestamp(1)
+ embed = await self.cog.create_user_embed(ctx, user, False)
infraction_counts.assert_called_once_with(user)
nomination_counts.assert_called_once_with(user)
self.assertEqual(
textwrap.dedent(f"""
- Created: {"1 year ago"}
+ Created: {"<t:1:R>"}
Profile: {user.mention}
ID: {user.id}
""").strip(),
@@ -356,7 +376,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
textwrap.dedent(f"""
- Joined: {"1 year ago"}
+ Joined: {"<t:1:R>"}
Verified: {"True"}
Roles: &Moderators
""").strip(),
@@ -364,22 +384,29 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
)
@unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
- async def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
+ @unittest.mock.patch(f"{COG_PATH}.user_messages", new_callable=unittest.mock.AsyncMock)
+ async def test_create_user_embed_basic_information_outside_of_moderation_channels(
+ self,
+ user_messages,
+ infraction_counts,
+ ):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
moderators_role = helpers.MockRole(name='Moderators')
infraction_counts.return_value = ("Infractions", "basic infractions info")
+ user_messages.return_value = ("Messages", "user message counts")
user = helpers.MockMember(id=314, roles=[moderators_role], colour=100)
- embed = await self.cog.create_user_embed(ctx, user)
+ user.created_at = user.joined_at = datetime.utcfromtimestamp(1)
+ embed = await self.cog.create_user_embed(ctx, user, False)
infraction_counts.assert_called_once_with(user)
self.assertEqual(
textwrap.dedent(f"""
- Created: {"1 year ago"}
+ Created: {"<t:1:R>"}
Profile: {user.mention}
ID: {user.id}
""").strip(),
@@ -388,21 +415,30 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
textwrap.dedent(f"""
- Joined: {"1 year ago"}
+ Joined: {"<t:1:R>"}
Roles: &Moderators
""").strip(),
embed.fields[1].value
)
self.assertEqual(
- "basic infractions info",
+ "user message counts",
embed.fields[2].value
)
+ self.assertEqual(
+ "basic infractions info",
+ embed.fields[3].value
+ )
+
@unittest.mock.patch(
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
async def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -410,7 +446,8 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
moderators_role = helpers.MockRole(name='Moderators')
user = helpers.MockMember(id=314, roles=[moderators_role], colour=100)
- embed = await self.cog.create_user_embed(ctx, user)
+ user.created_at = user.joined_at = datetime.utcnow()
+ embed = await self.cog.create_user_embed(ctx, user, False)
self.assertEqual(embed.colour, discord.Colour(100))
@@ -418,28 +455,37 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
- async def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
- """The embed should be created with a blurple colour if the user has no assigned roles."""
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
+ async def test_create_user_embed_uses_og_blurple_colour_when_user_has_no_roles(self):
+ """The embed should be created with the og blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
user = helpers.MockMember(id=217, colour=discord.Colour.default())
- embed = await self.cog.create_user_embed(ctx, user)
+ user.created_at = user.joined_at = datetime.utcnow()
+ embed = await self.cog.create_user_embed(ctx, user, False)
- self.assertEqual(embed.colour, discord.Colour.blurple())
+ self.assertEqual(embed.colour, discord.Colour.og_blurple())
@unittest.mock.patch(
f"{COG_PATH}.basic_user_infraction_counts",
new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions"))
)
+ @unittest.mock.patch(
+ f"{COG_PATH}.user_messages",
+ new=unittest.mock.AsyncMock(return_value=("Messsages", "user message count"))
+ )
async def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
user = helpers.MockMember(id=217, colour=0)
- user.avatar_url_as.return_value = "avatar url"
- embed = await self.cog.create_user_embed(ctx, user)
+ user.created_at = user.joined_at = datetime.utcnow()
+ user.display_avatar.url = "avatar url"
+ embed = await self.cog.create_user_embed(ctx, user, False)
- user.avatar_url_as.assert_called_once_with(static_format="png")
self.assertEqual(embed.thumbnail.url, "avatar url")
@@ -491,7 +537,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase):
await self.cog.user_info(self.cog, ctx)
- create_embed.assert_called_once_with(ctx, self.author)
+ create_embed.assert_called_once_with(ctx, self.author, False)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
@@ -502,28 +548,28 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase):
await self.cog.user_info(self.cog, ctx, self.author)
- create_embed.assert_called_once_with(ctx, self.author)
+ create_embed.assert_called_once_with(ctx, self.author, False)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
async def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
- constants.STAFF_ROLES = [self.moderator_role.id]
+ constants.STAFF_PARTNERS_COMMUNITY_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))
await self.cog.user_info(self.cog, ctx)
- create_embed.assert_called_once_with(ctx, self.moderator)
+ create_embed.assert_called_once_with(ctx, self.moderator, False)
ctx.send.assert_called_once()
@unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")
async def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
- constants.STAFF_ROLES = [self.moderator_role.id]
+ constants.STAFF_PARTNERS_COMMUNITY_ROLES = [self.moderator_role.id]
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50))
await self.cog.user_info(self.cog, ctx, self.target)
- create_embed.assert_called_once_with(ctx, self.target)
+ create_embed.assert_called_once_with(ctx, self.target, False)
ctx.send.assert_called_once()
diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py
index b9d527770..052048053 100644
--- a/tests/bot/exts/moderation/infraction/test_infractions.py
+++ b/tests/bot/exts/moderation/infraction/test_infractions.py
@@ -1,11 +1,15 @@
import inspect
import textwrap
import unittest
-from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch
+from unittest.mock import ANY, AsyncMock, DEFAULT, MagicMock, Mock, patch
+
+from discord.errors import NotFound
from bot.constants import Event
+from bot.exts.moderation.clean import Clean
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction.infractions import Infractions
+from bot.exts.moderation.infraction.management import ModManagement
from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockUser, autospec
@@ -13,12 +17,13 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
"""Tests for ban and kick command reason truncation."""
def setUp(self):
+ self.me = MockMember(id=7890, roles=[MockRole(id=7890, position=5)])
self.bot = MockBot()
self.cog = Infractions(self.bot)
- self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10))
- self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0))
+ self.user = MockMember(id=1234, roles=[MockRole(id=3577, position=10)])
+ self.target = MockMember(id=1265, roles=[MockRole(id=9876, position=1)])
self.guild = MockGuild(id=4567)
- self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild)
+ self.ctx = MockContext(me=self.me, bot=self.bot, author=self.user, guild=self.guild)
@patch("bot.exts.moderation.infraction._utils.get_active_infraction")
@patch("bot.exts.moderation.infraction._utils.post_infraction")
@@ -59,70 +64,70 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456)
-class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
- """Tests for voice ban related functions and commands."""
+class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for voice mute related functions and commands."""
def setUp(self):
self.bot = MockBot()
- self.mod = MockMember(top_role=10)
- self.user = MockMember(top_role=1, roles=[MockRole(id=123456)])
+ self.mod = MockMember(roles=[MockRole(id=7890123, position=10)])
+ self.user = MockMember(roles=[MockRole(id=123456, position=1)])
self.guild = MockGuild()
self.ctx = MockContext(bot=self.bot, author=self.mod)
self.cog = Infractions(self.bot)
- async def test_permanent_voice_ban(self):
- """Should call voice ban applying function without expiry."""
- self.cog.apply_voice_ban = AsyncMock()
- self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar"))
- self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None)
+ async def test_permanent_voice_mute(self):
+ """Should call voice mute applying function without expiry."""
+ self.cog.apply_voice_mute = AsyncMock()
+ self.assertIsNone(await self.cog.voicemute(self.cog, self.ctx, self.user, reason="foobar"))
+ self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at=None)
- async def test_temporary_voice_ban(self):
- """Should call voice ban applying function with expiry."""
- self.cog.apply_voice_ban = AsyncMock()
- self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar"))
- self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz")
+ async def test_temporary_voice_mute(self):
+ """Should call voice mute applying function with expiry."""
+ self.cog.apply_voice_mute = AsyncMock()
+ self.assertIsNone(await self.cog.tempvoicemute(self.cog, self.ctx, self.user, "baz", reason="foobar"))
+ self.cog.apply_voice_mute.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz")
- async def test_voice_unban(self):
+ async def test_voice_unmute(self):
"""Should call infraction pardoning function."""
self.cog.pardon_infraction = AsyncMock()
- self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user))
- self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user)
+ self.assertIsNone(await self.cog.unvoicemute(self.cog, self.ctx, self.user))
+ self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_mute", self.user)
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
- async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock):
- """Should return early when user already have Voice Ban infraction."""
+ async def test_voice_mute_user_have_active_infraction(self, get_active_infraction, post_infraction_mock):
+ """Should return early when user already have Voice Mute infraction."""
get_active_infraction.return_value = {"foo": "bar"}
- self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
- get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban")
+ self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar"))
+ get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_mute")
post_infraction_mock.assert_not_awaited()
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
- async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock):
+ async def test_voice_mute_infraction_post_failed(self, get_active_infraction, post_infraction_mock):
"""Should return early when posting infraction fails."""
self.cog.mod_log.ignore = MagicMock()
get_active_infraction.return_value = None
post_infraction_mock.return_value = None
- self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar"))
post_infraction_mock.assert_awaited_once()
self.cog.mod_log.ignore.assert_not_called()
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
- async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock):
- """Should pass all kwargs passed to apply_voice_ban to post_infraction."""
+ async def test_voice_mute_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock):
+ """Should pass all kwargs passed to apply_voice_mute to post_infraction."""
get_active_infraction.return_value = None
# We don't want that this continue yet
post_infraction_mock.return_value = None
- self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23))
+ self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar", my_kwarg=23))
post_infraction_mock.assert_awaited_once_with(
- self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23
+ self.ctx, self.user, "voice_mute", "foobar", active=True, my_kwarg=23
)
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
- async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock):
+ async def test_voice_mute_mod_log_ignore(self, get_active_infraction, post_infraction_mock):
"""Should ignore Voice Verified role removing."""
self.cog.mod_log.ignore = MagicMock()
self.cog.apply_infraction = AsyncMock()
@@ -131,11 +136,11 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
get_active_infraction.return_value = None
post_infraction_mock.return_value = {"foo": "bar"}
- self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar"))
self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id)
async def action_tester(self, action, reason: str) -> None:
- """Helper method to test voice ban action."""
+ """Helper method to test voice mute action."""
self.assertTrue(inspect.iscoroutine(action))
await action
@@ -144,7 +149,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
- async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock):
+ async def test_voice_mute_apply_infraction(self, get_active_infraction, post_infraction_mock):
"""Should ignore Voice Verified role removing."""
self.cog.mod_log.ignore = MagicMock()
self.cog.apply_infraction = AsyncMock()
@@ -153,22 +158,22 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
post_infraction_mock.return_value = {"foo": "bar"}
reason = "foobar"
- self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, reason))
+ self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, reason))
self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, ANY)
await self.action_tester(self.cog.apply_infraction.call_args[0][-1], reason)
@patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
@patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
- async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock):
- """Should truncate reason for voice ban."""
+ async def test_voice_mute_truncate_reason(self, get_active_infraction, post_infraction_mock):
+ """Should truncate reason for voice mute."""
self.cog.mod_log.ignore = MagicMock()
self.cog.apply_infraction = AsyncMock()
get_active_infraction.return_value = None
post_infraction_mock.return_value = {"foo": "bar"}
- self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000))
+ self.assertIsNone(await self.cog.apply_voice_mute(self.ctx, self.user, "foobar" * 3000))
self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, ANY)
# Test action
@@ -177,14 +182,14 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
@autospec(_utils, "post_infraction", "get_active_infraction", return_value=None)
@autospec(Infractions, "apply_infraction")
- async def test_voice_ban_user_left_guild(self, apply_infraction_mock, post_infraction_mock, _):
- """Should voice ban user that left the guild without throwing an error."""
+ async def test_voice_mute_user_left_guild(self, apply_infraction_mock, post_infraction_mock, _):
+ """Should voice mute user that left the guild without throwing an error."""
infraction = {"foo": "bar"}
post_infraction_mock.return_value = {"foo": "bar"}
user = MockUser()
- await self.cog.voiceban(self.cog, self.ctx, user, reason=None)
- post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_ban", None, active=True, expires_at=None)
+ await self.cog.voicemute(self.cog, self.ctx, user, reason=None)
+ post_infraction_mock.assert_called_once_with(self.ctx, user, "voice_mute", None, active=True, expires_at=None)
apply_infraction_mock.assert_called_once_with(self.cog, self.ctx, infraction, user, ANY)
# Test action
@@ -192,21 +197,22 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue(inspect.iscoroutine(action))
await action
- async def test_voice_unban_user_not_found(self):
+ async def test_voice_unmute_user_not_found(self):
"""Should include info to return dict when user was not found from guild."""
self.guild.get_member.return_value = None
- result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.guild.fetch_member.side_effect = NotFound(Mock(status=404), "Not found")
+ result = await self.cog.pardon_voice_mute(self.user.id, self.guild)
self.assertEqual(result, {"Info": "User was not found in the guild."})
@patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
@patch("bot.exts.moderation.infraction.infractions.format_user")
- async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock):
+ async def test_voice_unmute_user_found(self, format_user_mock, notify_pardon_mock):
"""Should add role back with ignoring, notify user and return log dictionary.."""
self.guild.get_member.return_value = self.user
notify_pardon_mock.return_value = True
format_user_mock.return_value = "my-user"
- result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ result = await self.cog.pardon_voice_mute(self.user.id, self.guild)
self.assertEqual(result, {
"Member": "my-user",
"DM": "Sent"
@@ -215,15 +221,100 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
@patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
@patch("bot.exts.moderation.infraction.infractions.format_user")
- async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock):
+ async def test_voice_unmute_dm_fail(self, format_user_mock, notify_pardon_mock):
"""Should add role back with ignoring, notify user and return log dictionary.."""
self.guild.get_member.return_value = self.user
notify_pardon_mock.return_value = False
format_user_mock.return_value = "my-user"
- result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ result = await self.cog.pardon_voice_mute(self.user.id, self.guild)
self.assertEqual(result, {
"Member": "my-user",
"DM": "**Failed**"
})
notify_pardon_mock.assert_awaited_once()
+
+
+class CleanBanTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for cleanban functionality."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.mod = MockMember(roles=[MockRole(id=7890123, position=10)])
+ self.user = MockMember(roles=[MockRole(id=123456, position=1)])
+ self.guild = MockGuild()
+ self.ctx = MockContext(bot=self.bot, author=self.mod)
+ self.cog = Infractions(self.bot)
+ self.clean_cog = Clean(self.bot)
+ self.management_cog = ModManagement(self.bot)
+
+ self.cog.apply_ban = AsyncMock(return_value={"id": 42})
+ self.log_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
+ self.clean_cog._clean_messages = AsyncMock(return_value=self.log_url)
+
+ def mock_get_cog(self, enable_clean, enable_manage):
+ """Mock get cog factory that allows the user to specify whether clean and manage cogs are enabled."""
+ def inner(name):
+ if name == "ModManagement":
+ return self.management_cog if enable_manage else None
+ elif name == "Clean":
+ return self.clean_cog if enable_clean else None
+ else:
+ return DEFAULT
+ return inner
+
+ async def test_cleanban_falls_back_to_native_purge_without_clean_cog(self):
+ """Should fallback to native purge if the Clean cog is not available."""
+ self.bot.get_cog.side_effect = self.mock_get_cog(False, False)
+
+ self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar"))
+ self.cog.apply_ban.assert_awaited_once_with(
+ self.ctx,
+ self.user,
+ "FooBar",
+ purge_days=1,
+ expires_at=None,
+ )
+
+ async def test_cleanban_doesnt_purge_messages_if_clean_cog_available(self):
+ """Cleanban command should use the native purge messages if the clean cog is available."""
+ self.bot.get_cog.side_effect = self.mock_get_cog(True, False)
+
+ self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar"))
+ self.cog.apply_ban.assert_awaited_once_with(
+ self.ctx,
+ self.user,
+ "FooBar",
+ expires_at=None,
+ )
+
+ @patch("bot.exts.moderation.infraction.infractions.Age")
+ async def test_cleanban_uses_clean_cog_when_available(self, mocked_age_converter):
+ """Test cleanban uses the clean cog to clean messages if it's available."""
+ self.bot.api_client.patch = AsyncMock()
+ self.bot.get_cog.side_effect = self.mock_get_cog(True, False)
+
+ mocked_age_converter.return_value.convert = AsyncMock(return_value="81M")
+ self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar"))
+
+ self.clean_cog._clean_messages.assert_awaited_once_with(
+ self.ctx,
+ users=[self.user],
+ channels="*",
+ first_limit="81M",
+ attempt_delete_invocation=False,
+ )
+
+ async def test_cleanban_edits_infraction_reason(self):
+ """Ensure cleanban edits the ban reason with a link to the clean log."""
+ self.bot.get_cog.side_effect = self.mock_get_cog(True, True)
+
+ self.management_cog.infraction_append = AsyncMock()
+ self.assertIsNone(await self.cog.cleanban(self.cog, self.ctx, self.user, None, reason="FooBar"))
+
+ self.management_cog.infraction_append.assert_awaited_once_with(
+ self.ctx,
+ {"id": 42},
+ None,
+ reason=f"[Clean log]({self.log_url})"
+ )
diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py
index d35120992..ff81ddd65 100644
--- a/tests/bot/exts/moderation/infraction/test_utils.py
+++ b/tests/bot/exts/moderation/infraction/test_utils.py
@@ -15,7 +15,10 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
"""Tests Moderation utils."""
def setUp(self):
- self.bot = MockBot()
+ patcher = patch("bot.instance", new=MockBot())
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
+
self.member = MockMember(id=1234)
self.user = MockUser(id=1234)
self.ctx = MockContext(bot=self.bot, author=self.member)
@@ -94,8 +97,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"])
test_cases = [
test_case([], None, None, True),
- test_case([{"id": 123987}], {"id": 123987}, "123987", False),
- test_case([{"id": 123987}], {"id": 123987}, "123987", True)
+ test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False),
+ test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True)
]
for case in test_cases:
@@ -123,6 +126,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
else:
self.ctx.send.assert_not_awaited()
+ @unittest.skip("Current time needs to be patched so infraction duration is correct.")
@patch("bot.exts.moderation.infraction._utils.send_private_embed")
async def test_send_infraction_embed(self, send_private_embed_mock):
"""
@@ -132,95 +136,95 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
"""
test_cases = [
{
- "args": (self.user, "ban", "2020-02-26 09:20 (23 hours and 59 minutes)"),
+ "args": (dict(id=0, type="ban", reason=None, expires_at=datetime(2020, 2, 26, 9, 20)), self.user),
"expected_output": Embed(
title=utils.INFRACTION_TITLE,
description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(
type="Ban",
- expires="2020-02-26 09:20 (23 hours and 59 minutes) UTC",
+ expires="2020-02-26 09:20 (23 hours and 59 minutes)",
reason="No reason provided."
- ),
+ ) + utils.INFRACTION_APPEAL_SERVER_FOOTER,
colour=Colours.soft_red,
url=utils.RULES_URL
).set_author(
name=utils.INFRACTION_AUTHOR_NAME,
url=utils.RULES_URL,
- icon_url=Icons.token_removed
- ).set_footer(text=utils.INFRACTION_APPEAL_MODMAIL_FOOTER),
+ icon_url=Icons.user_ban
+ ),
"send_result": True
},
{
- "args": (self.user, "warning", None, "Test reason."),
+ "args": (dict(id=0, type="warning", reason="Test reason.", expires_at=None), self.user),
"expected_output": Embed(
title=utils.INFRACTION_TITLE,
description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(
type="Warning",
expires="N/A",
reason="Test reason."
- ),
+ ) + utils.INFRACTION_APPEAL_MODMAIL_FOOTER,
colour=Colours.soft_red,
url=utils.RULES_URL
).set_author(
name=utils.INFRACTION_AUTHOR_NAME,
url=utils.RULES_URL,
- icon_url=Icons.token_removed
- ).set_footer(text=utils.INFRACTION_APPEAL_MODMAIL_FOOTER),
+ icon_url=Icons.user_warn
+ ),
"send_result": False
},
# Note that this test case asserts that the DM that *would* get sent to the user is formatted
# correctly, even though that message is deliberately never sent.
{
- "args": (self.user, "note", None, None, Icons.defcon_denied),
+ "args": (dict(id=0, type="note", reason=None, expires_at=None), self.user),
"expected_output": Embed(
title=utils.INFRACTION_TITLE,
description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(
type="Note",
expires="N/A",
reason="No reason provided."
- ),
+ ) + utils.INFRACTION_APPEAL_MODMAIL_FOOTER,
colour=Colours.soft_red,
url=utils.RULES_URL
).set_author(
name=utils.INFRACTION_AUTHOR_NAME,
url=utils.RULES_URL,
- icon_url=Icons.defcon_denied
- ).set_footer(text=utils.INFRACTION_APPEAL_MODMAIL_FOOTER),
+ icon_url=Icons.user_warn
+ ),
"send_result": False
},
{
- "args": (self.user, "mute", "2020-02-26 09:20 (23 hours and 59 minutes)", "Test", Icons.defcon_denied),
+ "args": (dict(id=0, type="mute", reason="Test", expires_at=datetime(2020, 2, 26, 9, 20)), self.user),
"expected_output": Embed(
title=utils.INFRACTION_TITLE,
description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(
type="Mute",
- expires="2020-02-26 09:20 (23 hours and 59 minutes) UTC",
+ expires="2020-02-26 09:20 (23 hours and 59 minutes)",
reason="Test"
- ),
+ ) + utils.INFRACTION_APPEAL_MODMAIL_FOOTER,
colour=Colours.soft_red,
url=utils.RULES_URL
).set_author(
name=utils.INFRACTION_AUTHOR_NAME,
url=utils.RULES_URL,
- icon_url=Icons.defcon_denied
- ).set_footer(text=utils.INFRACTION_APPEAL_MODMAIL_FOOTER),
+ icon_url=Icons.user_mute
+ ),
"send_result": False
},
{
- "args": (self.user, "mute", None, "foo bar" * 4000, Icons.defcon_denied),
+ "args": (dict(id=0, type="mute", reason="foo bar" * 4000, expires_at=None), self.user),
"expected_output": Embed(
title=utils.INFRACTION_TITLE,
description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(
type="Mute",
expires="N/A",
reason="foo bar" * 4000
- )[:2045] + "...",
+ )[:4093-utils.LONGEST_EXTRAS] + "..." + utils.INFRACTION_APPEAL_MODMAIL_FOOTER,
colour=Colours.soft_red,
url=utils.RULES_URL
).set_author(
name=utils.INFRACTION_AUTHOR_NAME,
url=utils.RULES_URL,
- icon_url=Icons.defcon_denied
- ).set_footer(text=utils.INFRACTION_APPEAL_MODMAIL_FOOTER),
+ icon_url=Icons.user_mute
+ ),
"send_result": True
}
]
@@ -230,7 +234,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
send_private_embed_mock.reset_mock()
send_private_embed_mock.return_value = case["send_result"]
- result = await utils.send_infraction_embed(*case["args"])
+ result = await utils.notify_infraction(*case["args"])
self.assertEqual(case["send_result"], result)
@@ -238,7 +242,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(embed.to_dict(), case["expected_output"].to_dict())
- send_private_embed_mock.assert_awaited_once_with(case["args"][0], embed)
+ send_private_embed_mock.assert_awaited_once_with(case["args"][1], embed)
@patch("bot.exts.moderation.infraction._utils.send_private_embed")
async def test_notify_pardon(self, send_private_embed_mock):
@@ -313,7 +317,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):
"type": "ban",
"user": self.member.id,
"active": False,
- "expires_at": now.isoformat()
+ "expires_at": now.isoformat(),
+ "dm_sent": False
}
self.ctx.bot.api_client.post.return_value = "foo"
@@ -350,7 +355,8 @@ class TestPostInfraction(unittest.IsolatedAsyncioTestCase):
"reason": "Test reason",
"type": "mute",
"user": self.user.id,
- "active": True
+ "active": True,
+ "dm_sent": False
}
self.bot.api_client.post.side_effect = [ResponseCodeError(MagicMock(status=400), {"user": "foo"}), "foo"]
diff --git a/tests/bot/exts/moderation/test_clean.py b/tests/bot/exts/moderation/test_clean.py
new file mode 100644
index 000000000..d7647fa48
--- /dev/null
+++ b/tests/bot/exts/moderation/test_clean.py
@@ -0,0 +1,104 @@
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from bot.exts.moderation.clean import Clean
+from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockMessage, MockRole, MockTextChannel
+
+
+class CleanTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for clean cog functionality."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.mod = MockMember(roles=[MockRole(id=7890123, position=10)])
+ self.user = MockMember(roles=[MockRole(id=123456, position=1)])
+ self.guild = MockGuild()
+ self.ctx = MockContext(bot=self.bot, author=self.mod)
+ self.cog = Clean(self.bot)
+
+ self.log_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
+ self.cog._modlog_cleaned_messages = AsyncMock(return_value=self.log_url)
+
+ self.cog._use_cache = MagicMock(return_value=True)
+ self.cog._delete_found = AsyncMock(return_value=[42, 84])
+
+ @patch("bot.exts.moderation.clean.is_mod_channel")
+ async def test_clean_deletes_invocation_in_non_mod_channel(self, mod_channel_check):
+ """Clean command should delete the invocation message if ran in a non mod channel."""
+ mod_channel_check.return_value = False
+ self.ctx.message.delete = AsyncMock()
+
+ self.assertIsNone(await self.cog._delete_invocation(self.ctx))
+
+ self.ctx.message.delete.assert_awaited_once()
+
+ @patch("bot.exts.moderation.clean.is_mod_channel")
+ async def test_clean_doesnt_delete_invocation_in_mod_channel(self, mod_channel_check):
+ """Clean command should not delete the invocation message if ran in a mod channel."""
+ mod_channel_check.return_value = True
+ self.ctx.message.delete = AsyncMock()
+
+ self.assertIsNone(await self.cog._delete_invocation(self.ctx))
+
+ self.ctx.message.delete.assert_not_awaited()
+
+ async def test_clean_doesnt_attempt_deletion_when_attempt_delete_invocation_is_false(self):
+ """Clean command should not attempt to delete the invocation message if attempt_delete_invocation is false."""
+ self.cog._delete_invocation = AsyncMock()
+ self.bot.get_channel = MagicMock(return_value=False)
+
+ self.assertEqual(
+ await self.cog._clean_messages(
+ self.ctx,
+ None,
+ first_limit=MockMessage(),
+ attempt_delete_invocation=False,
+ ),
+ self.log_url,
+ )
+
+ self.cog._delete_invocation.assert_not_awaited()
+
+ @patch("bot.exts.moderation.clean.is_mod_channel")
+ async def test_clean_replies_with_success_message_when_ran_in_mod_channel(self, mod_channel_check):
+ """Clean command should reply to the message with a confirmation message if invoked in a mod channel."""
+ mod_channel_check.return_value = True
+ self.ctx.reply = AsyncMock()
+
+ self.assertEqual(
+ await self.cog._clean_messages(
+ self.ctx,
+ None,
+ first_limit=MockMessage(),
+ attempt_delete_invocation=False,
+ ),
+ self.log_url,
+ )
+
+ self.ctx.reply.assert_awaited_once()
+ sent_message = self.ctx.reply.await_args[0][0]
+ self.assertIn(self.log_url, sent_message)
+ self.assertIn("2 messages", sent_message)
+
+ @patch("bot.exts.moderation.clean.is_mod_channel")
+ async def test_clean_send_success_message_to_mods_when_ran_in_non_mod_channel(self, mod_channel_check):
+ """Clean command should send a confirmation message to #mods if invoked in a non-mod channel."""
+ mod_channel_check.return_value = False
+ mocked_mods = MockTextChannel(id=1234567)
+ mocked_mods.send = AsyncMock()
+ self.bot.get_channel = MagicMock(return_value=mocked_mods)
+
+ self.assertEqual(
+ await self.cog._clean_messages(
+ self.ctx,
+ None,
+ first_limit=MockMessage(),
+ attempt_delete_invocation=False,
+ ),
+ self.log_url,
+ )
+
+ mocked_mods.send.assert_awaited_once()
+ sent_message = mocked_mods.send.await_args[0][0]
+ self.assertIn(self.log_url, sent_message)
+ self.assertIn("2 messages", sent_message)
diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py
index cbf7f7bcf..cfe0c4b03 100644
--- a/tests/bot/exts/moderation/test_incidents.py
+++ b/tests/bot/exts/moderation/test_incidents.py
@@ -3,23 +3,19 @@ import enum
import logging
import typing as t
import unittest
-from unittest.mock import AsyncMock, MagicMock, call, patch
+from unittest import mock
+from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
import aiohttp
import discord
+from async_rediscache import RedisSession
from bot.constants import Colours
from bot.exts.moderation import incidents
+from bot.utils.messages import format_user
from tests.helpers import (
- MockAsyncWebhook,
- MockAttachment,
- MockBot,
- MockMember,
- MockMessage,
- MockReaction,
- MockRole,
- MockTextChannel,
- MockUser,
+ MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,
+ MockUser
)
@@ -283,6 +279,22 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):
the instance as they wish.
"""
+ session = None
+
+ async def flush(self):
+ """Flush everything from the database to prevent carry-overs between tests."""
+ with await self.session.pool as connection:
+ await connection.flushall()
+
+ async def asyncSetUp(self): # noqa: N802
+ self.session = RedisSession(use_fakeredis=True)
+ await self.session.connect()
+ await self.flush()
+
+ async def asyncTearDown(self): # noqa: N802
+ if self.session:
+ await self.session.close()
+
def setUp(self):
"""
Prepare a fresh `Incidents` instance for each test.
@@ -379,7 +391,7 @@ class TestArchive(TestIncidents):
# Define our own `incident` to be archived
incident = MockMessage(
content="this is an incident",
- author=MockUser(name="author_name", avatar_url="author_avatar"),
+ author=MockUser(name="author_name", display_avatar=Mock(url="author_avatar")),
id=123,
)
built_embed = MagicMock(discord.Embed, id=123) # We patch `make_embed` to return this
@@ -513,7 +525,7 @@ class TestProcessEvent(TestIncidents):
with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):
await self.cog_instance.process_event(
reaction=incidents.Signal.ACTIONED.value,
- incident=MockMessage(),
+ incident=MockMessage(id=123),
member=MockMember(roles=[MockRole(id=1)])
)
@@ -533,7 +545,7 @@ class TestProcessEvent(TestIncidents):
with patch("bot.exts.moderation.incidents.Incidents.make_confirmation_task", mock_task):
await self.cog_instance.process_event(
reaction=incidents.Signal.ACTIONED.value,
- incident=MockMessage(),
+ incident=MockMessage(id=123),
member=MockMember(roles=[MockRole(id=1)])
)
except asyncio.TimeoutError:
@@ -768,3 +780,74 @@ class TestOnMessage(TestIncidents):
await self.cog_instance.on_message(MockMessage())
mock_add_signals.assert_not_called()
+
+
+class TestMessageLinkEmbeds(TestIncidents):
+ """Tests for `extract_message_links` coroutine."""
+
+ async def test_shorten_text(self):
+ """Test all cases of text shortening by mocking messages."""
+ tests = {
+ "thisisasingleword"*10: "thisisasinglewordthisisasinglewordthisisasinglewor...",
+
+ "\n".join("Lets make a new line test".split()): "Lets\nmake\na...",
+
+ 'Hello, World!' * 300: (
+ "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!"
+ "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!"
+ "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!"
+ "Hello, World!Hello, World!H..."
+ )
+ }
+
+ for content, expected_conversion in tests.items():
+ with self.subTest(content=content, expected_conversion=expected_conversion):
+ conversion = incidents.shorten_text(content)
+ self.assertEqual(conversion, expected_conversion)
+
+ async def extract_and_form_message_link_embeds(self):
+ """
+ Extract message links from a mocked message and form the message link embed.
+
+ Considers all types of message links, discord supports.
+ """
+ self.guild_id_patcher = mock.patch("bot.exts.backend.sync._cog.constants.Guild.id", 5)
+ self.guild_id = self.guild_id_patcher.start()
+
+ msg = MockMessage(id=555, content="Hello, World!" * 3000)
+ msg.channel.mention = "#lemonade-stand"
+
+ msg_links = [
+ # Valid Message links
+ f"https://discord.com/channels/{self.guild_id}/{msg.channel.discord_id}/{msg.discord_id}",
+ f"http://canary.discord.com/channels/{self.guild_id}/{msg.channel.discord_id}/{msg.discord_id}",
+
+ # Invalid Message links
+ f"https://discord.com/channels/{msg.channel.discord_id}/{msg.discord_id}",
+ f"https://discord.com/channels/{self.guild_id}/{msg.channel.discord_id}000/{msg.discord_id}",
+ ]
+
+ incident_msg = MockMessage(
+ id=777,
+ content=(
+ f"I would like to report the following messages, "
+ f"as they break our rules: \n{', '.join(msg_links)}"
+ )
+ )
+
+ with patch(
+ "bot.exts.moderation.incidents.Incidents.extract_message_links", AsyncMock()
+ ) as mock_extract_message_links:
+ embeds = mock_extract_message_links(incident_msg)
+ description = (
+ f"**Author:** {format_user(msg.author)}\n"
+ f"**Channel:** {msg.channel.mention} ({msg.channel.category}/#{msg.channel.name})\n"
+ f"**Content:** {('Hello, World!' * 3000)[:300] + '...'}\n"
+ )
+
+ # Check number of embeds returned with number of valid links
+ self.assertEqual(len(embeds), 2)
+
+ # Check for the embed descriptions
+ for embed in embeds:
+ self.assertEqual(embed.description, description)
diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py
index f8f142484..79e04837d 100644
--- a/tests/bot/exts/moderation/test_modlog.py
+++ b/tests/bot/exts/moderation/test_modlog.py
@@ -25,5 +25,5 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase):
)
embed = self.channel.send.call_args[1]["embed"]
self.assertEqual(
- embed.description, ("foo bar" * 3000)[:2045] + "..."
+ embed.description, ("foo bar" * 3000)[:4093] + "..."
)
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index 59a5893ef..92ce3418a 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -12,14 +12,7 @@ from discord import PermissionOverwrite
from bot.constants import Channels, Guild, MODERATION_ROLES, Roles
from bot.exts.moderation import silence
from tests.helpers import (
- MockBot,
- MockContext,
- MockGuild,
- MockMember,
- MockRole,
- MockTextChannel,
- MockVoiceChannel,
- autospec
+ MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec
)
redis_session = None
@@ -438,7 +431,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
asyncio.run(self.cog._async_init()) # Populate instance attributes.
self.text_channel = MockTextChannel()
- self.text_overwrite = PermissionOverwrite(send_messages=True, add_reactions=False)
+ self.text_overwrite = PermissionOverwrite(
+ send_messages=True,
+ add_reactions=False,
+ create_private_threads=True,
+ create_public_threads=False,
+ send_messages_in_threads=True
+ )
self.text_channel.overwrites_for.return_value = self.text_overwrite
self.voice_channel = MockVoiceChannel()
@@ -509,9 +508,39 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_skipped_already_silenced(self):
"""Permissions were not set and `False` was returned for an already silenced channel."""
subtests = (
- (False, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)),
- (True, MockTextChannel(), PermissionOverwrite(send_messages=True, add_reactions=True)),
- (True, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (
+ False,
+ MockTextChannel(),
+ PermissionOverwrite(
+ send_messages=False,
+ add_reactions=False,
+ create_private_threads=False,
+ create_public_threads=False,
+ send_messages_in_threads=False
+ )
+ ),
+ (
+ True,
+ MockTextChannel(),
+ PermissionOverwrite(
+ send_messages=True,
+ add_reactions=True,
+ create_private_threads=True,
+ create_public_threads=True,
+ send_messages_in_threads=True
+ )
+ ),
+ (
+ True,
+ MockTextChannel(),
+ PermissionOverwrite(
+ send_messages=False,
+ add_reactions=False,
+ create_private_threads=False,
+ create_public_threads=False,
+ send_messages_in_threads=False
+ )
+ ),
(False, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)),
(True, MockVoiceChannel(), PermissionOverwrite(connect=True, speak=True)),
(True, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)),
@@ -559,11 +588,16 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
await self.cog._set_silence_overwrites(self.text_channel)
new_overwrite_dict = dict(self.text_overwrite)
- # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.
- del prev_overwrite_dict['send_messages']
- del prev_overwrite_dict['add_reactions']
- del new_overwrite_dict['send_messages']
- del new_overwrite_dict['add_reactions']
+ # Remove related permission keys because they were changed by the method.
+ for perm_name in (
+ "send_messages",
+ "add_reactions",
+ "create_private_threads",
+ "create_public_threads",
+ "send_messages_in_threads"
+ ):
+ del prev_overwrite_dict[perm_name]
+ del new_overwrite_dict[perm_name]
self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
@@ -601,7 +635,10 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_cached_previous_overwrites(self):
"""Channel's previous overwrites were cached."""
- overwrite_json = '{"send_messages": true, "add_reactions": false}'
+ overwrite_json = (
+ '{"send_messages": true, "add_reactions": false, "create_private_threads": true, '
+ '"create_public_threads": false, "send_messages_in_threads": true}'
+ )
await self.cog._set_silence_overwrites(self.text_channel)
self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 321a92445..8bdeedd27 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -2,6 +2,7 @@ import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch
+from discord import AllowedMentions
from discord.ext import commands
from bot import constants
@@ -201,7 +202,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx = MockContext()
ctx.message = MockMessage()
ctx.send = AsyncMock()
- ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.author = MockUser(mention='@LemonLemonishBeard#0042')
self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})
self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
@@ -213,9 +214,16 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.get_cog.return_value = mocked_filter_cog
await self.cog.send_eval(ctx, 'MyAwesomeCode')
- ctx.send.assert_called_once_with(
+
+ ctx.send.assert_called_once()
+ self.assertEqual(
+ ctx.send.call_args.args[0],
'@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```\n[No output]\n```'
)
+ allowed_mentions = ctx.send.call_args.kwargs['allowed_mentions']
+ expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
+ self.assertEqual(allowed_mentions.to_dict(), expected_allowed_mentions.to_dict())
+
self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
@@ -238,10 +246,14 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.get_cog.return_value = mocked_filter_cog
await self.cog.send_eval(ctx, 'MyAwesomeCode')
- ctx.send.assert_called_once_with(
+
+ ctx.send.assert_called_once()
+ self.assertEqual(
+ ctx.send.call_args.args[0],
'@LemonLemonishBeard#0042 :yay!: Return code 0.'
'\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com'
)
+
self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
@@ -263,9 +275,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.bot.get_cog.return_value = mocked_filter_cog
await self.cog.send_eval(ctx, 'MyAwesomeCode')
- ctx.send.assert_called_once_with(
+
+ ctx.send.assert_called_once()
+ self.assertEqual(
+ ctx.send.call_args.args[0],
'@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```\nBeard got stuck in the eval\n```'
)
+
self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index 6444532f2..f8805ac48 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -2,12 +2,14 @@ from typing import Iterable
from bot.rules import mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
+from tests.helpers import MockMember, MockMessage
-def make_msg(author: str, total_mentions: int) -> MockMessage:
+def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage:
"""Makes a message with `total_mentions` mentions."""
- return MockMessage(author=author, mentions=list(range(total_mentions)))
+ user_mentions = [MockMember() for _ in range(total_user_mentions)]
+ bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)]
+ return MockMessage(author=author, mentions=user_mentions+bot_mentions)
class TestMentions(RuleTest):
@@ -48,11 +50,27 @@ class TestMentions(RuleTest):
[make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
("bob",),
4,
- )
+ ),
+ DisallowedCase(
+ [make_msg("bob", 3, 1)],
+ ("bob",),
+ 3,
+ ),
)
await self.run_disallowed(cases)
+ async def test_ignore_bot_mentions(self):
+ """Messages with an allowed amount of mentions, also containing bot mentions."""
+ cases = (
+ [make_msg("bob", 0, 3)],
+ [make_msg("bob", 2, 1)],
+ [make_msg("bob", 1, 2), make_msg("bob", 1, 2)],
+ [make_msg("bob", 1, 5), make_msg("alice", 2, 5)]
+ )
+
+ await self.run_allowed(cases)
+
def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
last_message = case.recent_messages[0]
return tuple(
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index 2a1c4e543..1bb678db2 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -1,19 +1,12 @@
-import datetime
import re
import unittest
+from datetime import MAXYEAR, datetime, timezone
from unittest.mock import MagicMock, patch
from dateutil.relativedelta import relativedelta
from discord.ext.commands import BadArgument
-from bot.converters import (
- Duration,
- HushDurationConverter,
- ISODateTime,
- PackageName,
- TagContentConverter,
- TagNameConverter,
-)
+from bot.converters import Duration, HushDurationConverter, ISODateTime, PackageName
class ConverterTests(unittest.IsolatedAsyncioTestCase):
@@ -24,59 +17,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
cls.context = MagicMock
cls.context.author = 'bob'
- cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
-
- async def test_tag_content_converter_for_valid(self):
- """TagContentConverter should return correct values for valid input."""
- test_values = (
- ('hello', 'hello'),
- (' h ello ', 'h ello'),
- )
-
- for content, expected_conversion in test_values:
- with self.subTest(content=content, expected_conversion=expected_conversion):
- conversion = await TagContentConverter.convert(self.context, content)
- self.assertEqual(conversion, expected_conversion)
-
- async def test_tag_content_converter_for_invalid(self):
- """TagContentConverter should raise the proper exception for invalid input."""
- test_values = (
- ('', "Tag contents should not be empty, or filled with whitespace."),
- (' ', "Tag contents should not be empty, or filled with whitespace."),
- )
-
- for value, exception_message in test_values:
- with self.subTest(tag_content=value, exception_message=exception_message):
- with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
- await TagContentConverter.convert(self.context, value)
-
- async def test_tag_name_converter_for_valid(self):
- """TagNameConverter should return the correct values for valid tag names."""
- test_values = (
- ('tracebacks', 'tracebacks'),
- ('Tracebacks', 'tracebacks'),
- (' Tracebacks ', 'tracebacks'),
- )
-
- for name, expected_conversion in test_values:
- with self.subTest(name=name, expected_conversion=expected_conversion):
- conversion = await TagNameConverter.convert(self.context, name)
- self.assertEqual(conversion, expected_conversion)
-
- async def test_tag_name_converter_for_invalid(self):
- """TagNameConverter should raise the correct exception for invalid tag names."""
- test_values = (
- ('👋', "Don't be ridiculous, you can't use that character!"),
- ('', "Tag names should not be empty, or filled with whitespace."),
- (' ', "Tag names should not be empty, or filled with whitespace."),
- ('42', "Tag names must contain at least one letter."),
- ('x' * 128, "Are you insane? That's way too long!"),
- )
-
- for invalid_name, exception_message in test_values:
- with self.subTest(invalid_name=invalid_name, exception_message=exception_message):
- with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
- await TagNameConverter.convert(self.context, invalid_name)
+ cls.fixed_utc_now = datetime.fromisoformat('2019-01-01T00:00:00+00:00')
async def test_package_name_for_valid(self):
"""PackageName returns valid package names unchanged."""
@@ -155,7 +96,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
expected_datetime = self.fixed_utc_now + relativedelta(**duration_dict)
with patch('bot.converters.datetime') as mock_datetime:
- mock_datetime.utcnow.return_value = self.fixed_utc_now
+ mock_datetime.now.return_value = self.fixed_utc_now
with self.subTest(duration=duration, duration_dict=duration_dict):
converted_datetime = await converter.convert(self.context, duration)
@@ -201,52 +142,53 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
async def test_duration_converter_out_of_range(self, mock_datetime):
"""Duration converter should raise BadArgument if datetime raises a ValueError."""
mock_datetime.__add__.side_effect = ValueError
- mock_datetime.utcnow.return_value = mock_datetime
+ mock_datetime.now.return_value = mock_datetime
- duration = f"{datetime.MAXYEAR}y"
+ duration = f"{MAXYEAR}y"
exception_message = f"`{duration}` results in a datetime outside the supported range."
with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
await Duration().convert(self.context, duration)
async def test_isodatetime_converter_for_valid(self):
"""ISODateTime converter returns correct datetime for valid datetime string."""
+ utc = timezone.utc
test_values = (
# `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
- ('2019-09-02T02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02 02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T02:03:05Z', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02 02:03:05Z', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
# `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM`
- ('2019-09-02T03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02 03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02T00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02 00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T03:18:05+01:15', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02 03:18:05+01:15', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02T00:48:05-01:15', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02 00:48:05-01:15', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
# `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM`
- ('2019-09-02T03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02 03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02T00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02 00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T03:18:05+0115', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02 03:18:05+0115', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02T00:48:05-0115', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02 00:48:05-0115', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
# `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH`
- ('2019-09-02 03:03:05+01', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02T01:03:05-01', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 03:03:05+01', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02T01:03:05-01', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
# `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS`
- ('2019-09-02T02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)),
- ('2019-09-02 02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T02:03:05', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
+ ('2019-09-02 02:03:05', datetime(2019, 9, 2, 2, 3, 5, tzinfo=utc)),
# `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM`
- ('2019-11-12T09:15', datetime.datetime(2019, 11, 12, 9, 15)),
- ('2019-11-12 09:15', datetime.datetime(2019, 11, 12, 9, 15)),
+ ('2019-11-12T09:15', datetime(2019, 11, 12, 9, 15, tzinfo=utc)),
+ ('2019-11-12 09:15', datetime(2019, 11, 12, 9, 15, tzinfo=utc)),
# `YYYY-mm-dd`
- ('2019-04-01', datetime.datetime(2019, 4, 1)),
+ ('2019-04-01', datetime(2019, 4, 1, tzinfo=utc)),
# `YYYY-mm`
- ('2019-02-01', datetime.datetime(2019, 2, 1)),
+ ('2019-02-01', datetime(2019, 2, 1, tzinfo=utc)),
# `YYYY`
- ('2025', datetime.datetime(2025, 1, 1)),
+ ('2025', datetime(2025, 1, 1, tzinfo=utc)),
)
converter = ISODateTime()
@@ -254,7 +196,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
for datetime_string, expected_dt in test_values:
with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt):
converted_dt = await converter.convert(self.context, datetime_string)
- self.assertIsNone(converted_dt.tzinfo)
self.assertEqual(converted_dt, expected_dt)
async def test_isodatetime_converter_for_invalid(self):
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index 883465e0b..4ae11d5d3 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -32,6 +32,7 @@ class ChecksTests(unittest.IsolatedAsyncioTestCase):
async def test_has_no_roles_check_without_guild(self):
"""`has_no_roles_check` should return `False` when `Context.guild` is None."""
self.ctx.channel = MagicMock(DMChannel)
+ self.ctx.guild = None
self.assertFalse(await checks.has_no_roles_check(self.ctx))
async def test_has_no_roles_check_returns_false_with_unwanted_role(self):
diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py
new file mode 100644
index 000000000..04bfd28d1
--- /dev/null
+++ b/tests/bot/utils/test_message_cache.py
@@ -0,0 +1,214 @@
+import unittest
+
+from bot.utils.message_cache import MessageCache
+from tests.helpers import MockMessage
+
+
+# noinspection SpellCheckingInspection
+class TestMessageCache(unittest.TestCase):
+ """Tests for the MessageCache class in the `bot.utils.caching` module."""
+
+ def test_first_append_sets_the_first_value(self):
+ """Test if the first append adds the message to the first cell."""
+ cache = MessageCache(maxlen=10)
+ message = MockMessage()
+
+ cache.append(message)
+
+ self.assertEqual(cache[0], message)
+
+ def test_append_adds_in_the_right_order(self):
+ """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise."""
+ messages = [MockMessage(), MockMessage()]
+
+ cache = MessageCache(maxlen=10, newest_first=False)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages, list(cache))
+
+ cache = MessageCache(maxlen=10, newest_first=True)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages[::-1], list(cache))
+
+ def test_appending_over_maxlen_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages."""
+ cache = MessageCache(maxlen=2)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_appending_over_maxlen_with_newest_first_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True."""
+ cache = MessageCache(maxlen=2, newest_first=True)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[:0:-1], list(cache))
+
+ def test_pop_removes_from_the_end(self):
+ """Test if a pop removes the right-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.pop()
+
+ self.assertEqual(msg, messages[-1])
+ self.assertListEqual(messages[:-1], list(cache))
+
+ def test_popleft_removes_from_the_beginning(self):
+ """Test if a popleft removes the left-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.popleft()
+
+ self.assertEqual(msg, messages[0])
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_clear(self):
+ """Test if a clear makes the cache empty."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ cache.clear()
+
+ self.assertListEqual(list(cache), [])
+ self.assertEqual(len(cache), 0)
+
+ def test_get_message_returns_the_message(self):
+ """Test if get_message returns the cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertEqual(cache.get_message(1234), message)
+
+ def test_get_message_returns_none(self):
+ """Test if get_message returns None for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIsNone(cache.get_message(4321))
+
+ def test_update_replaces_old_element(self):
+ """Test if an update replaced the old message with the same ID."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+ message = MockMessage(id=1234)
+ cache.update(message)
+
+ self.assertIs(cache.get_message(1234), message)
+ self.assertEqual(len(cache), 1)
+
+ def test_contains_returns_true_for_cached_message(self):
+ """Test if contains returns True for an ID of a cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIn(1234, cache)
+
+ def test_contains_returns_false_for_non_cached_message(self):
+ """Test if contains returns False for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertNotIn(4321, cache)
+
+ def test_indexing(self):
+ """Test if the cache returns the correct messages by index."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(5)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in range(-5, 5):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(cache[current_loop], messages[current_loop])
+
+ def test_bad_index_raises_index_error(self):
+ """Test if the cache raises IndexError for invalid indices."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+ test_cases = (-10, -4, 3, 4, 5)
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with self.assertRaises(IndexError):
+ cache[current_loop]
+
+ def test_slicing_with_unfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache is not yet fully filled."""
+ sizes = (5, 10, 55, 101)
+
+ slices = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1)
+ )
+
+ for size in sizes:
+ cache = MessageCache(maxlen=size)
+ messages = [MockMessage() for _ in range(size // 3 * 2)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ for slice_ in slices:
+ with self.subTest(current_loop=(size, slice_)):
+ self.assertListEqual(cache[slice_], messages[slice_])
+
+ def test_slicing_with_overfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache was appended with more messages it can contain."""
+ sizes = (5, 10, 55, 101)
+
+ slices = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1)
+ )
+
+ for size in sizes:
+ cache = MessageCache(maxlen=size)
+ messages = [MockMessage() for _ in range(size * 3 // 2)]
+
+ for msg in messages:
+ cache.append(msg)
+ messages = messages[size // 2:]
+
+ for slice_ in slices:
+ with self.subTest(current_loop=(size, slice_)):
+ self.assertListEqual(cache[slice_], messages[slice_])
+
+ def test_length(self):
+ """Test if len returns the correct number of items in the cache."""
+ cache = MessageCache(maxlen=5)
+
+ for current_loop in range(10):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(len(cache), min(current_loop, 5))
+ cache.append(MockMessage())
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
index 115ddfb0d..120d65176 100644
--- a/tests/bot/utils/test_time.py
+++ b/tests/bot/utils/test_time.py
@@ -13,13 +13,15 @@ class TimeTests(unittest.TestCase):
"""humanize_delta should be able to handle unknown units, and will not abort."""
# Does not abort for unknown units, as the unit name is checked
# against the attribute of the relativedelta instance.
- self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours')
+ actual = time.humanize_delta(relativedelta(days=2, hours=2), precision='elephants', max_units=2)
+ self.assertEqual(actual, '2 days and 2 hours')
def test_humanize_delta_handle_high_units(self):
"""humanize_delta should be able to handle very high units."""
# Very high maximum units, but it only ever iterates over
# each value the relativedelta might have.
- self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours')
+ actual = time.humanize_delta(relativedelta(days=2, hours=2), precision='hours', max_units=20)
+ self.assertEqual(actual, '2 days and 2 hours')
def test_humanize_delta_should_normal_usage(self):
"""Testing humanize delta."""
@@ -32,7 +34,8 @@ class TimeTests(unittest.TestCase):
for delta, precision, max_units, expected in test_cases:
with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected):
- self.assertEqual(time.humanize_delta(delta, precision, max_units), expected)
+ actual = time.humanize_delta(delta, precision=precision, max_units=max_units)
+ self.assertEqual(actual, expected)
def test_humanize_delta_raises_for_invalid_max_units(self):
"""humanize_delta should raises ValueError('max_units must be positive') for invalid max_units."""
@@ -40,22 +43,11 @@ class TimeTests(unittest.TestCase):
for max_units in test_cases:
with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
- time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
- self.assertEqual(str(error.exception), 'max_units must be positive')
-
- def test_parse_rfc1123(self):
- """Testing parse_rfc1123."""
- self.assertEqual(
- time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'),
- datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)
- )
-
- def test_format_infraction(self):
- """Testing format_infraction."""
- self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01')
+ time.humanize_delta(relativedelta(days=2, hours=2), precision='hours', max_units=max_units)
+ self.assertEqual(str(error.exception), 'max_units must be positive.')
- def test_format_infraction_with_duration_none_expiry(self):
- """format_infraction_with_duration should work for None expiry."""
+ def test_format_with_duration_none_expiry(self):
+ """format_with_duration should work for None expiry."""
test_cases = (
(None, None, None, None),
@@ -67,82 +59,71 @@ class TimeTests(unittest.TestCase):
for expiry, date_from, max_units, expected in test_cases:
with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
- self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+ self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected)
- def test_format_infraction_with_duration_custom_units(self):
- """format_infraction_with_duration should work for custom max_units."""
+ def test_format_with_duration_custom_units(self):
+ """format_with_duration should work for custom max_units."""
test_cases = (
- ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6,
- '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'),
- ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20,
- '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)')
+ ('3000-12-12T00:01:00Z', datetime(3000, 12, 11, 12, 5, 5, tzinfo=timezone.utc), 6,
+ '<t:32533488060:f> (11 hours, 55 minutes and 55 seconds)'),
+ ('3000-11-23T20:09:00Z', datetime(3000, 4, 25, 20, 15, tzinfo=timezone.utc), 20,
+ '<t:32531918940:f> (6 months, 28 days, 23 hours and 54 minutes)')
)
for expiry, date_from, max_units, expected in test_cases:
with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
- self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+ self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected)
- def test_format_infraction_with_duration_normal_usage(self):
- """format_infraction_with_duration should work for normal usage, across various durations."""
+ def test_format_with_duration_normal_usage(self):
+ """format_with_duration should work for normal usage, across various durations."""
+ utc = timezone.utc
test_cases = (
- ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'),
- ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'),
- ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'),
- ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'),
- ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'),
- ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'),
- ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'),
- ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'),
- ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2,
- '2019-11-23 23:59 (9 minutes and 55 seconds)'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5, tzinfo=utc), 2,
+ '<t:1576108860:f> (12 hours and 55 seconds)'),
+ ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5, tzinfo=utc), 1, '<t:1576108860:f> (12 hours)'),
+ ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59, tzinfo=utc), 2, '<t:1576108800:f> (1 minute)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15, tzinfo=utc), 2,
+ '<t:1574539740:f> (7 days and 23 hours)'),
+ ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15, tzinfo=utc), 2,
+ '<t:1574539740:f> (6 months and 28 days)'),
+ ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53, tzinfo=utc), 2, '<t:1574542680:f> (5 minutes)'),
+ ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0, tzinfo=utc), 2, '<t:1574553600:f> (1 minute)'),
+ ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0, tzinfo=utc), 2,
+ '<t:1574553540:f> (2 years and 4 months)'),
+ ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5, tzinfo=utc), 2,
+ '<t:1574553540:f> (9 minutes and 55 seconds)'),
(None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
)
for expiry, date_from, max_units, expected in test_cases:
with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected):
- self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected)
+ self.assertEqual(time.format_with_duration(expiry, date_from, max_units), expected)
def test_until_expiration_with_duration_none_expiry(self):
- """until_expiration should work for None expiry."""
- test_cases = (
- (None, None, None, None),
-
- # To make sure that now and max_units are not touched
- (None, 'Why hello there!', None, None),
- (None, None, float('inf'), None),
- (None, 'Why hello there!', float('inf'), None),
- )
-
- for expiry, now, max_units, expected in test_cases:
- with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
- self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+ """until_expiration should return "Permanent" is expiry is None."""
+ self.assertEqual(time.until_expiration(None), "Permanent")
def test_until_expiration_with_duration_custom_units(self):
"""until_expiration should work for custom max_units."""
test_cases = (
- ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'),
- ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes')
+ ('3000-12-12T00:01:00Z', '<t:32533488060:R>'),
+ ('3000-11-23T20:09:00Z', '<t:32531918940:R>')
)
- for expiry, now, max_units, expected in test_cases:
- with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
- self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+ for expiry, expected in test_cases:
+ with self.subTest(expiry=expiry, expected=expected):
+ self.assertEqual(time.until_expiration(expiry,), expected)
def test_until_expiration_normal_usage(self):
"""until_expiration should work for normal usage, across various durations."""
test_cases = (
- ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'),
- ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'),
- ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'),
- ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'),
- ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'),
- ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'),
- ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'),
- ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'),
- ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'),
- (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),
+ ('3000-12-12T00:01:00Z', '<t:32533488060:R>'),
+ ('3000-12-12T00:01:00Z', '<t:32533488060:R>'),
+ ('3000-12-12T00:00:00Z', '<t:32533488000:R>'),
+ ('3000-11-23T20:09:00Z', '<t:32531918940:R>'),
+ ('3000-11-23T20:09:00Z', '<t:32531918940:R>'),
)
- for expiry, now, max_units, expected in test_cases:
- with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected):
- self.assertEqual(time.until_expiration(expiry, now, max_units), expected)
+ for expiry, expected in test_cases:
+ with self.subTest(expiry=expiry, expected=expected):
+ self.assertEqual(time.until_expiration(expiry), expected)
diff --git a/tests/helpers.py b/tests/helpers.py
index 3978076ed..9d4988d23 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -39,7 +39,7 @@ class HashableMixin(discord.mixins.EqualityComparable):
class ColourMixin:
- """A mixin for Mocks that provides the aliasing of color->colour like discord.py does."""
+ """A mixin for Mocks that provides the aliasing of (accent_)color->(accent_)colour like discord.py does."""
@property
def color(self) -> discord.Colour:
@@ -49,6 +49,14 @@ class ColourMixin:
def color(self, color: discord.Colour) -> None:
self.colour = color
+ @property
+ def accent_color(self) -> discord.Colour:
+ return self.accent_colour
+
+ @accent_color.setter
+ def accent_color(self, color: discord.Colour) -> None:
+ self.accent_colour = color
+
class CustomMockMixin:
"""
@@ -235,13 +243,20 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
self.roles.extend(roles)
+ self.top_role = max(self.roles)
if 'mention' not in kwargs:
self.mention = f"@{self.name}"
# Create a User instance to get a realistic Mock of `discord.User`
-user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock())
+_user_data_mock = collections.defaultdict(unittest.mock.MagicMock, {
+ "accent_color": 0
+})
+user_instance = discord.User(
+ data=unittest.mock.MagicMock(get=unittest.mock.Mock(side_effect=_user_data_mock.get)),
+ state=unittest.mock.MagicMock()
+)
class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
@@ -278,7 +293,10 @@ def _get_mock_loop() -> unittest.mock.Mock:
# Since calling `create_task` on our MockBot does not actually schedule the coroutine object
# as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
# to prevent "has not been awaited"-warnings.
- loop.create_task.side_effect = lambda coroutine: coroutine.close()
+ def mock_create_task(coroutine, **kwargs):
+ coroutine.close()
+ return unittest.mock.Mock()
+ loop.create_task.side_effect = mock_create_task
return loop
@@ -424,7 +442,12 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da
# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
-context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
+context_instance = Context(
+ message=unittest.mock.MagicMock(),
+ prefix="$",
+ bot=MockBot(),
+ view=None
+)
context_instance.invoked_from_error_handler = None
@@ -439,6 +462,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
+ self.me = kwargs.get('me', MockMember())
self.bot = kwargs.get('bot', MockBot())
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
@@ -532,7 +556,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
self.__str__.return_value = str(self.emoji)
-webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock())
+webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock())
class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
diff --git a/tests/test_base.py b/tests/test_base.py
index a7db4bf3e..365805a71 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -1,8 +1,7 @@
import logging
-import unittest
import unittest.mock
-
+from bot.log import get_logger
from tests.base import LoggingTestsMixin, _CaptureLogHandler
@@ -15,7 +14,7 @@ class LoggingTestCaseTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.log = logging.getLogger(__name__)
+ cls.log = get_logger(__name__)
def test_assert_not_logs_does_not_raise_with_no_logs(self):
"""Test if LoggingTestCase.assertNotLogs does not raise when no logs were emitted."""
@@ -56,15 +55,15 @@ class LoggingTestCaseTests(unittest.TestCase):
def test_logging_test_case_works_with_logger_instance(self):
"""Test if the LoggingTestCase captures logging for provided logger."""
- log = logging.getLogger("new_logger")
+ log = get_logger("new_logger")
with self.assertRaises(AssertionError):
with LoggingTestCase.assertNotLogs(self, logger=log):
log.info("Hello, this should raise an AssertionError")
def test_logging_test_case_respects_alternative_logger(self):
"""Test if LoggingTestCase only checks the provided logger."""
- log_one = logging.getLogger("log one")
- log_two = logging.getLogger("log two")
+ log_one = get_logger("log one")
+ log_two = get_logger("log two")
with LoggingTestCase.assertNotLogs(self, logger=log_one):
log_two.info("Hello, this should not raise an AssertionError")