aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar ChrisJL <[email protected]>2025-03-29 20:10:49 +0000
committerGravatar GitHub <[email protected]>2025-03-29 20:10:49 +0000
commit9806725c7f3e4c2fe87a0a429ec13c6c37b4df1f (patch)
tree63400f90c066e3ac5a045d58670f16a763ea872d /tests
parentReminders: Simplify helper function to get button embed (diff)
parentBump aiohttp from 3.11.12 to 3.11.13 (#3280) (diff)
Merge branch 'main' into feat/reminder-add-notify
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/exts/backend/sync/test_cog.py64
-rw-r--r--tests/bot/exts/backend/sync/test_users.py9
-rw-r--r--tests/bot/exts/backend/test_error_handler.py17
-rw-r--r--tests/bot/exts/filtering/test_extension_filter.py34
-rw-r--r--tests/bot/exts/info/doc/test_parsing.py18
-rw-r--r--tests/bot/exts/moderation/infraction/test_infractions.py20
-rw-r--r--tests/bot/exts/moderation/test_silence.py136
-rw-r--r--tests/bot/exts/recruitment/talentpool/test_review.py40
-rw-r--r--tests/bot/exts/utils/snekbox/test_snekbox.py51
-rw-r--r--tests/bot/test_constants.py4
-rw-r--r--tests/bot/utils/test_helpers.py113
-rw-r--r--tests/helpers.py34
12 files changed, 359 insertions, 181 deletions
diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py
index 2ce950965..6d7356bf2 100644
--- a/tests/bot/exts/backend/sync/test_cog.py
+++ b/tests/bot/exts/backend/sync/test_cog.py
@@ -1,4 +1,6 @@
+import types
import unittest
+import unittest.mock
from unittest import mock
import discord
@@ -60,40 +62,54 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):
class SyncCogTests(SyncCogTestCase):
"""Tests for the Sync cog."""
- async def test_sync_cog_sync_on_load(self):
- """Roles and users should be synced on cog load."""
- guild = helpers.MockGuild()
- self.bot.get_guild = mock.MagicMock(return_value=guild)
-
- self.RoleSyncer.reset_mock()
- self.UserSyncer.reset_mock()
-
- await self.cog.cog_load()
-
- self.RoleSyncer.sync.assert_called_once_with(guild)
- self.UserSyncer.sync.assert_called_once_with(guild)
-
- async def test_sync_cog_sync_guild(self):
- """Roles and users should be synced only if a guild is successfully retrieved."""
+ @unittest.mock.patch("bot.exts.backend.sync._cog.create_task", new_callable=unittest.mock.MagicMock)
+ async def test_sync_cog_sync_on_load(self, mock_create_task: unittest.mock.MagicMock):
+ """Sync function should be synced on cog load only if guild is found."""
for guild in (helpers.MockGuild(), None):
with self.subTest(guild=guild):
+ mock_create_task.reset_mock()
self.bot.reset_mock()
self.RoleSyncer.reset_mock()
self.UserSyncer.reset_mock()
self.bot.get_guild = mock.MagicMock(return_value=guild)
-
- await self.cog.cog_load()
-
- self.bot.wait_until_guild_available.assert_called_once()
- self.bot.get_guild.assert_called_once_with(constants.Guild.id)
+ error_raised = False
+ try:
+ await self.cog.cog_load()
+ except ValueError:
+ if guild is None:
+ error_raised = True
+ else:
+ raise
if guild is None:
- self.RoleSyncer.sync.assert_not_called()
- self.UserSyncer.sync.assert_not_called()
+ self.assertTrue(error_raised)
+ mock_create_task.assert_not_called()
else:
- self.RoleSyncer.sync.assert_called_once_with(guild)
- self.UserSyncer.sync.assert_called_once_with(guild)
+ mock_create_task.assert_called_once()
+ create_task_arg = mock_create_task.call_args[0][0]
+ self.assertIsInstance(create_task_arg, types.CoroutineType)
+ self.assertEqual(create_task_arg.__qualname__, self.cog.sync.__qualname__)
+ create_task_arg.close()
+
+ async def test_sync_cog_sync_guild(self):
+ """Roles and users should be synced only if a guild is successfully retrieved."""
+ guild = helpers.MockGuild()
+ self.bot.reset_mock()
+ self.RoleSyncer.reset_mock()
+ self.UserSyncer.reset_mock()
+
+ self.bot.get_guild = mock.MagicMock(return_value=guild)
+ await self.cog.cog_load()
+
+ with mock.patch("asyncio.sleep", new_callable=unittest.mock.AsyncMock):
+ await self.cog.sync()
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(constants.Guild.id)
+
+ self.RoleSyncer.sync.assert_called_once()
+ self.UserSyncer.sync.assert_called_once()
async def patch_user_helper(self, side_effect: BaseException) -> None:
"""Helper to set a side effect for bot.api_client.patch and then assert it is called."""
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index 2fc97af2d..26bf0d9a0 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -11,6 +11,7 @@ def fake_user(**kwargs):
"""Fixture to return a dictionary representing a user with default values set."""
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
+ kwargs.setdefault("display_name", "bob")
kwargs.setdefault("discriminator", 1337)
kwargs.setdefault("roles", [helpers.MockRole(id=666)])
kwargs.setdefault("in_guild", True)
@@ -209,8 +210,8 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
diff = _Diff(self.users, [], None)
await UserSyncer._sync(diff)
- self.bot.api_client.post.assert_any_call("bot/users", json=diff.created[:self.chunk_size])
- self.bot.api_client.post.assert_any_call("bot/users", json=diff.created[self.chunk_size:])
+ self.bot.api_client.post.assert_any_call("bot/users", json=tuple(diff.created[:self.chunk_size]))
+ self.bot.api_client.post.assert_any_call("bot/users", json=tuple(diff.created[self.chunk_size:]))
self.assertEqual(self.bot.api_client.post.call_count, self.chunk_count)
self.bot.api_client.put.assert_not_called()
@@ -221,8 +222,8 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
diff = _Diff([], self.users, None)
await UserSyncer._sync(diff)
- self.bot.api_client.patch.assert_any_call("bot/users/bulk_patch", json=diff.updated[:self.chunk_size])
- self.bot.api_client.patch.assert_any_call("bot/users/bulk_patch", json=diff.updated[self.chunk_size:])
+ self.bot.api_client.patch.assert_any_call("bot/users/bulk_patch", json=tuple(diff.updated[:self.chunk_size]))
+ self.bot.api_client.patch.assert_any_call("bot/users/bulk_patch", json=tuple(diff.updated[self.chunk_size:]))
self.assertEqual(self.bot.api_client.patch.call_count, self.chunk_count)
self.bot.api_client.post.assert_not_called()
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 9670d42a0..85dc33999 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -414,12 +414,13 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
for case in test_cases:
with self.subTest(error=case["error"], call_prepared=case["call_prepared"]):
self.ctx.reset_mock()
+ self.cog.send_error_with_help = AsyncMock()
self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"]))
- self.ctx.send.assert_awaited_once()
if case["call_prepared"]:
- self.ctx.send_help.assert_awaited_once()
+ self.cog.send_error_with_help.assert_awaited_once()
else:
- self.ctx.send_help.assert_not_awaited()
+ self.ctx.send.assert_awaited_once()
+ self.cog.send_error_with_help.assert_not_awaited()
async def test_handle_check_failure_errors(self):
"""Should await `ctx.send` when error is check failure."""
@@ -494,26 +495,26 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
else:
log_mock.debug.assert_called_once()
- @patch("bot.exts.backend.error_handler.push_scope")
+ @patch("bot.exts.backend.error_handler.new_scope")
@patch("bot.exts.backend.error_handler.log")
- async def test_handle_unexpected_error(self, log_mock, push_scope_mock):
+ async def test_handle_unexpected_error(self, log_mock, new_scope_mock):
"""Should `ctx.send` this error, error log this and sent to Sentry."""
for case in (None, MockGuild()):
with self.subTest(guild=case):
self.ctx.reset_mock()
log_mock.reset_mock()
- push_scope_mock.reset_mock()
+ new_scope_mock.reset_mock()
scope_mock = Mock()
# Mock `with push_scope_mock() as scope:`
- push_scope_mock.return_value.__enter__.return_value = scope_mock
+ new_scope_mock.return_value.__enter__.return_value = scope_mock
self.ctx.guild = case
await self.cog.handle_unexpected_error(self.ctx, errors.CommandError())
self.ctx.send.assert_awaited_once()
log_mock.error.assert_called_once()
- push_scope_mock.assert_called_once()
+ new_scope_mock.assert_called_once()
set_tag_calls = [
call("command", self.ctx.command.qualified_name),
diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py
index f71de1e1b..67a503b30 100644
--- a/tests/bot/exts/filtering/test_extension_filter.py
+++ b/tests/bot/exts/filtering/test_extension_filter.py
@@ -69,40 +69,6 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []}))
@patch("bot.instance", BOT)
- async def test_python_file_redirect_embed_description(self):
- """A message containing a .py file should result in an embed redirecting the user to our paste site."""
- attachment = MockAttachment(filename="python.py")
- ctx = self.ctx.replace(attachments=[attachment])
-
- await self.filter_list.actions_for(ctx)
-
- self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION)
-
- @patch("bot.instance", BOT)
- async def test_txt_file_redirect_embed_description(self):
- """A message containing a .txt/.json/.csv file should result in the correct embed."""
- test_values = (
- ("text", ".txt"),
- ("json", ".json"),
- ("csv", ".csv"),
- )
-
- for file_name, disallowed_extension in test_values:
- with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension):
-
- attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}")
- ctx = self.ctx.replace(attachments=[attachment])
-
- await self.filter_list.actions_for(ctx)
-
- self.assertEqual(
- ctx.dm_embed,
- extension.TXT_EMBED_DESCRIPTION.format(
- blocked_extension=disallowed_extension,
- )
- )
-
- @patch("bot.instance", BOT)
async def test_other_disallowed_extension_embed_description(self):
"""Test the description for a non .py/.txt/.json/.csv disallowed extension."""
attachment = MockAttachment(filename="python.disallowed")
diff --git a/tests/bot/exts/info/doc/test_parsing.py b/tests/bot/exts/info/doc/test_parsing.py
index d2105a53c..7136fc32c 100644
--- a/tests/bot/exts/info/doc/test_parsing.py
+++ b/tests/bot/exts/info/doc/test_parsing.py
@@ -1,5 +1,7 @@
from unittest import TestCase
+from bs4 import BeautifulSoup
+
from bot.exts.info.doc import _parsing as parsing
from bot.exts.info.doc._markdown import DocMarkdownConverter
@@ -87,3 +89,19 @@ class MarkdownConverterTest(TestCase):
with self.subTest(input_string=input_string):
d = DocMarkdownConverter(page_url="https://example.com")
self.assertEqual(d.convert(input_string), expected_output)
+
+
+class MarkdownCreationTest(TestCase):
+ def test_surrounding_whitespace(self):
+ test_cases = (
+ ("<p>Hello World</p>", "Hello World"),
+ ("<p>Hello</p><p>World</p>", "Hello\n\nWorld"),
+ ("<h1>Title</h1>", "**Title**")
+ )
+ self._run_tests(test_cases)
+
+ def _run_tests(self, test_cases: tuple[tuple[str, str], ...]):
+ for input_string, expected_output in test_cases:
+ with self.subTest(input_string=input_string):
+ tags = BeautifulSoup(input_string, "html.parser")
+ self.assertEqual(parsing._create_markdown(None, tags, "https://example.com"), expected_output)
diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py
index 26ba770dc..f257bec7d 100644
--- a/tests/bot/exts/moderation/infraction/test_infractions.py
+++ b/tests/bot/exts/moderation/infraction/test_infractions.py
@@ -37,7 +37,9 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
self.cog.mod_log.ignore = Mock()
self.ctx.guild.ban = AsyncMock()
- await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000)
+ infraction_reason = "foo bar" * 3000
+
+ await self.cog.apply_ban(self.ctx, self.target, infraction_reason)
self.cog.apply_infraction.assert_awaited_once_with(
self.ctx, {"foo": "bar", "purge": ""}, self.target, ANY
)
@@ -46,10 +48,14 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
await action()
self.ctx.guild.ban.assert_awaited_once_with(
self.target,
- reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."),
+ reason=textwrap.shorten(infraction_reason, 512, placeholder="..."),
delete_message_days=0
)
+ # Assert that the reason sent to the database isn't truncated.
+ post_infraction_mock.assert_awaited_once()
+ self.assertEqual(post_infraction_mock.call_args.args[3], infraction_reason)
+
@patch("bot.exts.moderation.infraction._utils.post_infraction")
async def test_apply_kick_reason_truncation(self, post_infraction_mock):
"""Should truncate reason for `Member.kick`."""
@@ -59,14 +65,20 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
self.cog.mod_log.ignore = Mock()
self.target.kick = AsyncMock()
- await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000)
+ infraction_reason = "foo bar" * 3000
+
+ await self.cog.apply_kick(self.ctx, self.target, infraction_reason)
self.cog.apply_infraction.assert_awaited_once_with(
self.ctx, {"foo": "bar"}, self.target, ANY
)
action = self.cog.apply_infraction.call_args.args[-1]
await action()
- self.target.kick.assert_awaited_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))
+ self.target.kick.assert_awaited_once_with(reason=textwrap.shorten(infraction_reason, 512, placeholder="..."))
+
+ # Assert that the reason sent to the database isn't truncated.
+ post_infraction_mock.assert_awaited_once()
+ self.assertEqual(post_infraction_mock.call_args.args[3], infraction_reason)
@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456)
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index a7f239d7f..86d396afd 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -37,15 +37,13 @@ class SilenceTest(RedisTestCase):
self.bot = MockBot(get_channel=lambda _id: MockTextChannel(id=_id))
self.cog = silence.Silence(self.bot)
- @autospec(silence, "SilenceNotifier", pass_mocks=False)
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
await self.cog.cog_load() # Populate instance attributes.
-class SilenceNotifierTests(SilenceTest):
+class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
- super().setUp()
self.alert_channel = MockTextChannel()
self.notifier = silence.SilenceNotifier(self.alert_channel)
self.notifier.stop = self.notifier_stop_mock = Mock()
@@ -54,32 +52,36 @@ class SilenceNotifierTests(SilenceTest):
def test_add_channel_adds_channel(self):
"""Channel is added to `_silenced_channels` with the current loop."""
channel = Mock()
- with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
- self.notifier.add_channel(channel)
- silenced_channels.__setitem__.assert_called_with(channel, self.notifier._current_loop)
+ self.notifier.add_channel(channel)
+ self.assertDictEqual(self.notifier._silenced_channels, {channel: self.notifier._current_loop})
+
+ def test_add_channel_loop_called_correctly(self):
+ """Loop is called only in correct scenarios."""
- def test_add_channel_starts_loop(self):
- """Loop is started if `_silenced_channels` was empty."""
+ # Loop is started if `_silenced_channels` was empty.
self.notifier.add_channel(Mock())
self.notifier_start_mock.assert_called_once()
- def test_add_channel_skips_start_with_channels(self):
- """Loop start is not called when `_silenced_channels` is not empty."""
- with mock.patch.object(self.notifier, "_silenced_channels"):
- self.notifier.add_channel(Mock())
+ self.notifier_start_mock.reset_mock()
+
+ # Loop start is not called when `_silenced_channels` is not empty.
+ self.notifier.add_channel(Mock())
self.notifier_start_mock.assert_not_called()
def test_remove_channel_removes_channel(self):
"""Channel is removed from `_silenced_channels`."""
channel = Mock()
- with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
- self.notifier.remove_channel(channel)
- silenced_channels.__delitem__.assert_called_with(channel)
+ self.notifier.add_channel(channel)
+ self.notifier.remove_channel(channel)
+ self.assertDictEqual(self.notifier._silenced_channels, {})
def test_remove_channel_stops_loop(self):
"""Notifier loop is stopped if `_silenced_channels` is empty after remove."""
- with mock.patch.object(self.notifier, "_silenced_channels", __bool__=lambda _: False):
- self.notifier.remove_channel(Mock())
+ channel = Mock()
+ self.notifier.add_channel(channel)
+ self.notifier_stop_mock.assert_not_called()
+
+ self.notifier.remove_channel(channel)
self.notifier_stop_mock.assert_called_once()
def test_remove_channel_skips_stop_with_channels(self):
@@ -111,33 +113,28 @@ class SilenceNotifierTests(SilenceTest):
self.alert_channel.send.assert_not_called()
-@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
class SilenceCogTests(SilenceTest):
"""Tests for the general functionality of the Silence cog."""
- @autospec(silence, "SilenceNotifier", pass_mocks=False)
async def test_cog_load_got_guild(self):
"""Bot got guild after it became available."""
self.bot.wait_until_guild_available.assert_awaited_once()
self.bot.get_guild.assert_called_once_with(Guild.id)
- @autospec(silence, "SilenceNotifier", pass_mocks=False)
async def test_cog_load_got_channels(self):
"""Got channels from bot."""
- await self.cog.cog_load()
self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)
- @autospec(silence, "SilenceNotifier")
- async def test_cog_load_got_notifier(self, notifier):
+ async def test_cog_load_got_notifier(self):
"""Notifier was started with channel."""
- await self.cog.cog_load()
+ with mock.patch.object(silence, "SilenceNotifier") as notifier:
+ await self.cog.cog_load()
notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))
self.assertEqual(self.cog.notifier, notifier.return_value)
- @autospec(silence, "SilenceNotifier", pass_mocks=False)
async def testcog_load_rescheduled(self):
"""`_reschedule_` coroutine was awaited."""
- self.cog._reschedule = mock.create_autospec(self.cog._reschedule)
+ self.cog._reschedule = AsyncMock()
await self.cog.cog_load()
self.cog._reschedule.assert_awaited_once_with()
@@ -242,7 +239,7 @@ class SilenceCogTests(SilenceTest):
self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2)
-class SilenceArgumentParserTests(SilenceTest):
+class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the silence argument parser utility function."""
@autospec(silence.Silence, "send_message", pass_mocks=False)
@@ -250,6 +247,9 @@ class SilenceArgumentParserTests(SilenceTest):
@autospec(silence.Silence, "parse_silence_args")
async def test_command(self, parser_mock):
"""Test that the command passes in the correct arguments for different calls."""
+ bot = MockBot()
+ cog = silence.Silence(bot)
+
test_cases = (
(),
(15, ),
@@ -262,7 +262,7 @@ class SilenceArgumentParserTests(SilenceTest):
for case in test_cases:
with self.subTest("Test command converters", args=case):
- await self.cog.silence.callback(self.cog, ctx, *case)
+ await cog.silence.callback(cog, ctx, *case)
try:
first_arg = case[0]
@@ -281,7 +281,7 @@ class SilenceArgumentParserTests(SilenceTest):
async def test_no_arguments(self):
"""Test the parser when no arguments are passed to the command."""
ctx = MockContext()
- channel, duration = self.cog.parse_silence_args(ctx, None, 10)
+ channel, duration = silence.Silence.parse_silence_args(ctx, None, 10)
self.assertEqual(ctx.channel, channel)
self.assertEqual(10, duration)
@@ -289,7 +289,7 @@ class SilenceArgumentParserTests(SilenceTest):
async def test_channel_only(self):
"""Test the parser when just the channel argument is passed."""
expected_channel = MockTextChannel()
- actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 10)
+ actual_channel, duration = silence.Silence.parse_silence_args(MockContext(), expected_channel, 10)
self.assertEqual(expected_channel, actual_channel)
self.assertEqual(10, duration)
@@ -297,7 +297,7 @@ class SilenceArgumentParserTests(SilenceTest):
async def test_duration_only(self):
"""Test the parser when just the duration argument is passed."""
ctx = MockContext()
- channel, duration = self.cog.parse_silence_args(ctx, 15, 10)
+ channel, duration = silence.Silence.parse_silence_args(ctx, 15, 10)
self.assertEqual(ctx.channel, channel)
self.assertEqual(15, duration)
@@ -305,13 +305,12 @@ class SilenceArgumentParserTests(SilenceTest):
async def test_all_args(self):
"""Test the parser when both channel and duration are passed."""
expected_channel = MockTextChannel()
- actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 15)
+ actual_channel, duration = silence.Silence.parse_silence_args(MockContext(), expected_channel, 15)
self.assertEqual(expected_channel, actual_channel)
self.assertEqual(15, duration)
-@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
class RescheduleTests(RedisTestCase):
"""Tests for the rescheduling of cached unsilences."""
@@ -328,7 +327,7 @@ class RescheduleTests(RedisTestCase):
async def test_skipped_missing_channel(self):
"""Did nothing because the channel couldn't be retrieved."""
- self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)]
+ await self.cog.unsilence_timestamps.set(123, -1)
self.bot.get_channel.return_value = None
await self.cog._reschedule()
@@ -341,8 +340,8 @@ class RescheduleTests(RedisTestCase):
"""Permanently silenced channels were added to the notifier."""
channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
self.bot.get_channel.side_effect = channels
- self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)]
-
+ await self.cog.unsilence_timestamps.set(123, -1)
+ await self.cog.unsilence_timestamps.set(456, -1)
await self.cog._reschedule()
self.cog.notifier.add_channel.assert_any_call(channels[0])
@@ -355,7 +354,8 @@ class RescheduleTests(RedisTestCase):
"""Unsilenced expired silences."""
channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
self.bot.get_channel.side_effect = channels
- self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)]
+ await self.cog.unsilence_timestamps.set(123, 100)
+ await self.cog.unsilence_timestamps.set(456, 200)
await self.cog._reschedule()
@@ -370,7 +370,8 @@ class RescheduleTests(RedisTestCase):
"""Rescheduled active silences."""
channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
self.bot.get_channel.side_effect = channels
- self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)]
+ await self.cog.unsilence_timestamps.set(123, 2000)
+ await self.cog.unsilence_timestamps.set(456, 3000)
silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=UTC)
self.cog._unsilence_wrapper = mock.MagicMock()
@@ -398,7 +399,6 @@ def voice_sync_helper(function):
return inner
-@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
class SilenceTests(SilenceTest):
"""Tests for the silence command and its related helper methods."""
@@ -596,19 +596,28 @@ class SilenceTests(SilenceTest):
async def test_temp_not_added_to_notifier(self):
"""Channel was not added to notifier if a duration was set for the silence."""
- with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ with (
+ mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True),
+ mock.patch.object(self.cog.notifier, "add_channel")
+ ):
await self.cog.silence.callback(self.cog, MockContext(), 15)
self.cog.notifier.add_channel.assert_not_called()
async def test_indefinite_added_to_notifier(self):
"""Channel was added to notifier if a duration was not set for the silence."""
- with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ with (
+ mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True),
+ mock.patch.object(self.cog.notifier, "add_channel")
+ ):
await self.cog.silence.callback(self.cog, MockContext(), None, None)
self.cog.notifier.add_channel.assert_called_once()
async def test_silenced_not_added_to_notifier(self):
"""Channel was not added to the notifier if it was already silenced."""
- with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False):
+ with (
+ mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False),
+ mock.patch.object(self.cog.notifier, "add_channel")
+ ):
await self.cog.silence.callback(self.cog, MockContext(), 15)
self.cog.notifier.add_channel.assert_not_called()
@@ -619,7 +628,7 @@ class SilenceTests(SilenceTest):
'"create_public_threads": false, "send_messages_in_threads": true}'
)
await self.cog._set_silence_overwrites(self.text_channel)
- self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json)
+ self.assertEqual(await self.cog.previous_overwrites.get(self.text_channel.id), overwrite_json)
@autospec(silence, "datetime")
async def test_cached_unsilence_time(self, datetime_mock):
@@ -632,14 +641,14 @@ class SilenceTests(SilenceTest):
ctx = MockContext(channel=self.text_channel)
await self.cog.silence.callback(self.cog, ctx, duration)
- self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp)
+ self.assertEqual(await self.cog.unsilence_timestamps.get(ctx.channel.id), timestamp)
datetime_mock.now.assert_called_once_with(tz=UTC) # Ensure it's using an aware dt.
async def test_cached_indefinite_time(self):
"""A value of -1 was cached for a permanent silence."""
ctx = MockContext(channel=self.text_channel)
await self.cog.silence.callback(self.cog, ctx, None, None)
- self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)
+ self.assertEqual(await self.cog.unsilence_timestamps.get(ctx.channel.id), -1)
async def test_scheduled_task(self):
"""An unsilence task was scheduled."""
@@ -665,7 +674,6 @@ class SilenceTests(SilenceTest):
unsilence.assert_awaited_once_with(ctx, ctx.channel, None)
-@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)
class UnsilenceTests(SilenceTest):
"""Tests for the unsilence command and its related helper methods."""
@@ -681,13 +689,6 @@ class UnsilenceTests(SilenceTest):
self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)
self.voice_channel.overwrites_for.return_value = self.voice_overwrite
- async def asyncSetUp(self) -> None:
- await super().asyncSetUp()
- overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
- self.cog.previous_overwrites = overwrites_cache
-
- overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
-
async def test_sent_correct_message(self):
"""Appropriate failure/success message was sent by the command."""
unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True)
@@ -720,7 +721,6 @@ class UnsilenceTests(SilenceTest):
async def test_skipped_already_unsilenced(self):
"""Permissions were not set and `False` was returned for an already unsilenced channel."""
self.cog.scheduler.__contains__.return_value = False
- self.cog.previous_overwrites.get.return_value = None
for channel in (MockVoiceChannel(), MockTextChannel()):
with self.subTest(channel=channel):
@@ -729,6 +729,7 @@ class UnsilenceTests(SilenceTest):
async def test_restored_overwrites_text(self):
"""Text channel's `send_message` and `add_reactions` overwrites were restored."""
+ await self.cog.previous_overwrites.set(self.text_channel.id, '{"send_messages": true, "add_reactions": false}')
await self.cog._unsilence(self.text_channel)
self.text_channel.set_permissions.assert_awaited_once_with(
self.cog._everyone_role,
@@ -741,19 +742,18 @@ class UnsilenceTests(SilenceTest):
async def test_restored_overwrites_voice(self):
"""Voice channel's `connect` and `speak` overwrites were restored."""
+ await self.cog.previous_overwrites.set(self.voice_channel.id, '{"connect": true, "speak": true}')
await self.cog._unsilence(self.voice_channel)
self.voice_channel.set_permissions.assert_awaited_once_with(
self.cog._verified_voice_role,
overwrite=self.voice_overwrite,
)
- # Recall that these values are determined by the fixture.
self.assertTrue(self.voice_overwrite.connect)
self.assertTrue(self.voice_overwrite.speak)
async def test_cache_miss_used_default_overwrites_text(self):
"""Text overwrites were set to None due previous values not being found in the cache."""
- self.cog.previous_overwrites.get.return_value = None
await self.cog._unsilence(self.text_channel)
self.text_channel.set_permissions.assert_awaited_once_with(
@@ -766,7 +766,6 @@ class UnsilenceTests(SilenceTest):
async def test_cache_miss_used_default_overwrites_voice(self):
"""Voice overwrites were set to None due previous values not being found in the cache."""
- self.cog.previous_overwrites.get.return_value = None
await self.cog._unsilence(self.voice_channel)
self.voice_channel.set_permissions.assert_awaited_once_with(
@@ -779,30 +778,31 @@ class UnsilenceTests(SilenceTest):
async def test_cache_miss_sent_mod_alert_text(self):
"""A message was sent to the mod alerts channel upon muting a text channel."""
- self.cog.previous_overwrites.get.return_value = None
await self.cog._unsilence(self.text_channel)
self.cog._mod_alerts_channel.send.assert_awaited_once()
async def test_cache_miss_sent_mod_alert_voice(self):
"""A message was sent to the mod alerts channel upon muting a voice channel."""
- self.cog.previous_overwrites.get.return_value = None
await self.cog._unsilence(MockVoiceChannel())
self.cog._mod_alerts_channel.send.assert_awaited_once()
async def test_removed_notifier(self):
"""Channel was removed from `notifier`."""
- await self.cog._unsilence(self.text_channel)
- self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel)
+ with mock.patch.object(silence.SilenceNotifier, "remove_channel"):
+ await self.cog._unsilence(self.text_channel)
+ self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel)
async def test_deleted_cached_overwrite(self):
"""Channel was deleted from the overwrites cache."""
+ await self.cog.previous_overwrites.set(self.text_channel.id, '{"send_messages": true, "add_reactions": false}')
await self.cog._unsilence(self.text_channel)
- self.cog.previous_overwrites.delete.assert_awaited_once_with(self.text_channel.id)
+ self.assertEqual(await self.cog.previous_overwrites.get(self.text_channel.id), None)
async def test_deleted_cached_time(self):
"""Channel was deleted from the timestamp cache."""
+ await self.cog.unsilence_timestamps.set(self.text_channel.id, 100)
await self.cog._unsilence(self.text_channel)
- self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.text_channel.id)
+ self.assertEqual(await self.cog.unsilence_timestamps.get(self.text_channel.id), None)
async def test_cancelled_task(self):
"""The scheduled unsilence task should be cancelled."""
@@ -813,7 +813,10 @@ class UnsilenceTests(SilenceTest):
"""Text channel's other unrelated overwrites were not changed, including cache misses."""
for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):
with self.subTest(overwrite_json=overwrite_json):
- self.cog.previous_overwrites.get.return_value = overwrite_json
+ if overwrite_json is None:
+ await self.cog.previous_overwrites.delete(self.text_channel.id)
+ else:
+ await self.cog.previous_overwrites.set(self.text_channel.id, overwrite_json)
prev_overwrite_dict = dict(self.text_overwrite)
await self.cog._unsilence(self.text_channel)
@@ -831,7 +834,10 @@ class UnsilenceTests(SilenceTest):
"""Voice channel's other unrelated overwrites were not changed, including cache misses."""
for overwrite_json in ('{"connect": true, "speak": true}', None):
with self.subTest(overwrite_json=overwrite_json):
- self.cog.previous_overwrites.get.return_value = overwrite_json
+ if overwrite_json is None:
+ await self.cog.previous_overwrites.delete(self.voice_channel.id)
+ else:
+ await self.cog.previous_overwrites.set(self.voice_channel.id, overwrite_json)
prev_overwrite_dict = dict(self.voice_overwrite)
await self.cog._unsilence(self.voice_channel)
diff --git a/tests/bot/exts/recruitment/talentpool/test_review.py b/tests/bot/exts/recruitment/talentpool/test_review.py
index 25622e91f..8ec384bb2 100644
--- a/tests/bot/exts/recruitment/talentpool/test_review.py
+++ b/tests/bot/exts/recruitment/talentpool/test_review.py
@@ -3,7 +3,7 @@ from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch
from bot.exts.recruitment.talentpool import _review
-from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel
+from tests.helpers import MockBot, MockMember, MockMessage, MockReaction, MockTextChannel
class AsyncIterator:
@@ -63,8 +63,10 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the `is_ready_for_review` function."""
too_recent = datetime.now(UTC) - timedelta(hours=1)
not_too_recent = datetime.now(UTC) - timedelta(days=7)
+ ticket_reaction = MockReaction(users=[self.bot_user], emoji="\N{TICKET}")
+
cases = (
- # Only one review, and not too recent, so ready.
+ # Only one active review, and not too recent, so ready.
(
[
MockMessage(author=self.bot_user, content="wookie for Helper!", created_at=not_too_recent),
@@ -75,7 +77,7 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase):
True,
),
- # Three reviews, so not ready.
+ # Three active reviews, so not ready.
(
[
MockMessage(author=self.bot_user, content="Chrisjl for Helper!", created_at=not_too_recent),
@@ -86,7 +88,7 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase):
False,
),
- # Only one review, but too recent, so not ready.
+ # Only one active review, but too recent, so not ready.
(
[
MockMessage(author=self.bot_user, content="Chrisjl for Helper!", created_at=too_recent),
@@ -95,7 +97,7 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase):
False,
),
- # Only two reviews, and not too recent, so ready.
+ # Only two active reviews, and not too recent, so ready.
(
[
MockMessage(author=self.bot_user, content="Not a review", created_at=too_recent),
@@ -107,6 +109,34 @@ class ReviewerTests(unittest.IsolatedAsyncioTestCase):
True,
),
+ # Over the active threshold, but below the total threshold
+ (
+ [
+ MockMessage(
+ author=self.bot_user,
+ content="joe for Helper!",
+ created_at=not_too_recent,
+ reactions=[ticket_reaction]
+ )
+ ] * 6,
+ not_too_recent.timestamp(),
+ True
+ ),
+
+ # Over the total threshold
+ (
+ [
+ MockMessage(
+ author=self.bot_user,
+ content="joe for Helper!",
+ created_at=not_too_recent,
+ reactions=[ticket_reaction]
+ )
+ ] * 11,
+ not_too_recent.timestamp(),
+ False
+ ),
+
# No messages, so ready.
([], None, True),
)
diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py
index 8ee0f46ff..9cfd75df8 100644
--- a/tests/bot/exts/utils/snekbox/test_snekbox.py
+++ b/tests/bot/exts/utils/snekbox/test_snekbox.py
@@ -35,8 +35,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
context_manager = MagicMock()
context_manager.__aenter__.return_value = resp
self.bot.http_session.post.return_value = context_manager
-
- job = EvalJob.from_code("import random").as_version("3.10")
+ py_version = "3.12"
+ job = EvalJob.from_code("import random").as_version(py_version)
self.assertEqual(await self.cog.post_job(job), EvalResult("Hi", 137))
expected = {
@@ -44,9 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"files": [
{
"path": "main.py",
- "content": b64encode(b"import random").decode()
+ "content": b64encode(b"import random").decode(),
}
- ]
+ ],
+ "executable_path": f"/snekbin/python/{py_version}/bin/python",
}
self.bot.http_session.post.assert_called_with(
constants.URLs.snekbox_eval_api,
@@ -113,7 +114,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
result = EvalResult(stdout=stdout, returncode=returncode)
job = EvalJob([])
# Check all 3 message types
- msg = result.get_message(job)
+ msg = result.get_status_message(job)
self.assertEqual(msg, exp_msg)
error = result.error_message
self.assertEqual(error, exp_err)
@@ -166,7 +167,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
def test_eval_result_message_invalid_signal(self, _mock_signals: Mock):
result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- result.get_message(EvalJob([], version="3.10")),
+ result.get_status_message(EvalJob([], version="3.10")),
"Your 3.10 eval job has completed with return code 127"
)
self.assertEqual(result.error_message, "")
@@ -177,7 +178,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mock_signals.return_value.name = "SIGTEST"
result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- result.get_message(EvalJob([], version="3.12")),
+ result.get_status_message(EvalJob([], version="3.12")),
"Your 3.12 eval job has completed with return code 127 (SIGTEST)"
)
@@ -199,7 +200,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
too_many_lines = (
"001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n"
- "007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)"
+ "007 | n\n008 | g\n009 | b\n010 | e\n... (truncated - too many lines)"
)
too_long_too_many_lines = (
"\n".join(
@@ -292,7 +293,6 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_send_job(self):
"""Test the send_job function."""
ctx = MockContext()
- ctx.message = MockMessage()
ctx.send = AsyncMock()
ctx.author = MockUser(mention="@LemonLemonishBeard#0042")
@@ -311,8 +311,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- "@LemonLemonishBeard#0042 :warning: Your 3.12 eval job has completed "
- "with return code 0.\n\n```\n[No output]\n```"
+ ":warning: Your 3.12 eval job has completed "
+ "with return code 0.\n\n```ansi\n[No output]\n```"
)
allowed_mentions = ctx.send.call_args.kwargs["allowed_mentions"]
expected_allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author])
@@ -325,9 +325,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_send_job_with_paste_link(self):
"""Test the send_job function with a too long output that generate a paste link."""
ctx = MockContext()
- ctx.message = MockMessage()
ctx.send = AsyncMock()
- ctx.author.mention = "@LemonLemonishBeard#0042"
eval_result = EvalResult("Way too long beard", 0)
self.cog.post_job = AsyncMock(return_value=eval_result)
@@ -343,9 +341,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- "@LemonLemonishBeard#0042 :white_check_mark: Your 3.12 eval job "
+ ":white_check_mark: Your 3.12 eval job "
"has completed with return code 0."
- "\n\n```\nWay too long beard\n```\nFull output: lookatmybeard.com"
+ "\n\n```ansi\nWay too long beard\n```\nFull output: lookatmybeard.com"
)
self.cog.post_job.assert_called_once_with(job)
@@ -354,9 +352,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_send_job_with_non_zero_eval(self):
"""Test the send_job function with a code returning a non-zero code."""
ctx = MockContext()
- ctx.message = MockMessage()
ctx.send = AsyncMock()
- ctx.author.mention = "@LemonLemonishBeard#0042"
eval_result = EvalResult("ERROR", 127)
self.cog.post_job = AsyncMock(return_value=eval_result)
@@ -372,8 +368,8 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send.assert_called_once()
self.assertEqual(
ctx.send.call_args.args[0],
- "@LemonLemonishBeard#0042 :x: Your 3.12 eval job has completed with return code 127."
- "\n\n```\nERROR\n```"
+ ":x: Your 3.12 eval job has completed with return code 127."
+ "\n\n```ansi\nERROR\n```"
)
self.cog.post_job.assert_called_once_with(job)
@@ -382,16 +378,21 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_send_job_with_disallowed_file_ext(self):
"""Test send_job with disallowed file extensions."""
ctx = MockContext()
- ctx.message = MockMessage()
ctx.send = AsyncMock()
- ctx.author.mention = "@user#7700"
- eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")])
+ files = [
+ FileAttachment("test.disallowed2", b"test"),
+ FileAttachment("test.disallowed", b"test"),
+ FileAttachment("test.allowed", b"test"),
+ FileAttachment("test." + ("a" * 100), b"test")
+ ]
+ eval_result = EvalResult("", 0, files=files)
self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.upload_output = AsyncMock() # This function isn't called
+ disallowed_exts = [".disallowed", "." + ("a" * 100), ".disallowed2"]
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"]))
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, disallowed_exts))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code("MyAwesomeCode").as_version("3.12")
@@ -400,9 +401,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send.assert_called_once()
res = ctx.send.call_args.args[0]
self.assertTrue(
- res.startswith("@user#7700 :white_check_mark: Your 3.12 eval job has completed with return code 0.")
+ res.startswith(":white_check_mark: Your 3.12 eval job has completed with return code 0.")
)
- self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res)
+ self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed, .disallowed2, ...**", res)
self.cog.post_job.assert_called_once_with(job)
self.cog.upload_output.assert_not_called()
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py
index 87933d59a..916e1d5bb 100644
--- a/tests/bot/test_constants.py
+++ b/tests/bot/test_constants.py
@@ -10,7 +10,7 @@ current_path = Path(__file__)
env_file_path = current_path.parent / ".testenv"
-class TestEnvConfig(
+class _TestEnvConfig(
EnvConfig,
env_file=env_file_path,
):
@@ -21,7 +21,7 @@ class NestedModel(BaseModel):
server_name: str
-class _TestConfig(TestEnvConfig, env_prefix="unittests_"):
+class _TestConfig(_TestEnvConfig, env_prefix="unittests_"):
goat: str
execution_env: str = "local"
diff --git a/tests/bot/utils/test_helpers.py b/tests/bot/utils/test_helpers.py
new file mode 100644
index 000000000..e8ab6ba80
--- /dev/null
+++ b/tests/bot/utils/test_helpers.py
@@ -0,0 +1,113 @@
+import unittest
+
+from bot.utils import helpers
+
+
+class TestHelpers(unittest.TestCase):
+ """Tests for the helper functions in the `bot.utils.helpers` module."""
+
+ def test_find_nth_occurrence_returns_index(self):
+ """Test if `find_nth_occurrence` returns the index correctly when substring is found."""
+ test_values = (
+ ("hello", "l", 1, 2),
+ ("hello", "l", 2, 3),
+ ("hello world", "world", 1, 6),
+ ("hello world", " ", 1, 5),
+ ("hello world", "o w", 1, 4)
+ )
+
+ for string, substring, n, expected_index in test_values:
+ with self.subTest(string=string, substring=substring, n=n):
+ index = helpers.find_nth_occurrence(string, substring, n)
+ self.assertEqual(index, expected_index)
+
+ def test_find_nth_occurrence_returns_none(self):
+ """Test if `find_nth_occurrence` returns None when substring is not found."""
+ test_values = (
+ ("hello", "w", 1, None),
+ ("hello", "w", 2, None),
+ ("hello world", "world", 2, None),
+ ("hello world", " ", 2, None),
+ ("hello world", "o w", 2, None)
+ )
+
+ for string, substring, n, expected_index in test_values:
+ with self.subTest(string=string, substring=substring, n=n):
+ index = helpers.find_nth_occurrence(string, substring, n)
+ self.assertEqual(index, expected_index)
+
+ def test_has_lines_handles_normal_cases(self):
+ """Test if `has_lines` returns True for strings with at least `count` lines."""
+ test_values = (
+ ("hello\nworld", 1, True),
+ ("hello\nworld", 2, True),
+ ("hello\nworld", 3, False),
+ )
+
+ for string, count, expected in test_values:
+ with self.subTest(string=string, count=count):
+ result = helpers.has_lines(string, count)
+ self.assertEqual(result, expected)
+
+ def test_has_lines_handles_empty_string(self):
+ """Test if `has_lines` returns False for empty strings."""
+ test_values = (
+ ("", 0, False),
+ ("", 1, False),
+ )
+
+ for string, count, expected in test_values:
+ with self.subTest(string=string, count=count):
+ result = helpers.has_lines(string, count)
+ self.assertEqual(result, expected)
+
+ def test_has_lines_handles_newline_at_end(self):
+ """Test if `has_lines` ignores one newline at the end."""
+ test_values = (
+ ("hello\nworld\n", 2, True),
+ ("hello\nworld\n", 3, False),
+ ("hello\nworld\n\n", 3, True),
+ )
+
+ for string, count, expected in test_values:
+ with self.subTest(string=string, count=count):
+ result = helpers.has_lines(string, count)
+ self.assertEqual(result, expected)
+
+ def test_pad_base64_correctly(self):
+ """Test if `pad_base64` correctly pads a base64 string."""
+ test_values = (
+ ("", ""),
+ ("a", "a==="),
+ ("aa", "aa=="),
+ ("aaa", "aaa="),
+ ("aaaa", "aaaa"),
+ ("aaaaa", "aaaaa==="),
+ ("aaaaaa", "aaaaaa=="),
+ ("aaaaaaa", "aaaaaaa=")
+ )
+
+ for data, expected in test_values:
+ with self.subTest(data=data):
+ result = helpers.pad_base64(data)
+ self.assertEqual(result, expected)
+
+ def test_remove_subdomain_from_url_correctly(self):
+ """Test if `remove_subdomain_from_url` correctly removes subdomains from URLs."""
+ test_values = (
+ ("https://example.com", "https://example.com"),
+ ("https://www.example.com", "https://example.com"),
+ ("https://sub.example.com", "https://example.com"),
+ ("https://sub.sub.example.com", "https://example.com"),
+ ("https://sub.example.co.uk", "https://example.co.uk"),
+ ("https://sub.sub.example.co.uk", "https://example.co.uk"),
+ ("https://sub.example.co.uk/path", "https://example.co.uk/path"),
+ ("https://sub.sub.example.co.uk/path", "https://example.co.uk/path"),
+ ("https://sub.example.co.uk/path?query", "https://example.co.uk/path?query"),
+ ("https://sub.sub.example.co.uk/path?query", "https://example.co.uk/path?query"),
+ )
+
+ for url, expected in test_values:
+ with self.subTest(url=url):
+ result = helpers.remove_subdomain_from_url(url)
+ self.assertEqual(result, expected)
diff --git a/tests/helpers.py b/tests/helpers.py
index 580848c25..1164828d6 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -388,12 +388,17 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
spec_set = text_channel_instance
def __init__(self, **kwargs) -> None:
- default_kwargs = {"id": next(self.discord_id), "name": "channel", "guild": MockGuild()}
+ default_kwargs = {"id": next(self.discord_id), "name": "channel"}
super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if "mention" not in kwargs:
self.mention = f"#{self.name}"
+ @cached_property
+ def guild(self) -> MockGuild:
+ """Cached guild property."""
+ return MockGuild()
+
class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
@@ -405,12 +410,17 @@ class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
spec_set = voice_channel_instance
def __init__(self, **kwargs) -> None:
- default_kwargs = {"id": next(self.discord_id), "name": "channel", "guild": MockGuild()}
+ default_kwargs = {"id": next(self.discord_id), "name": "channel"}
super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if "mention" not in kwargs:
self.mention = f"#{self.name}"
+ @cached_property
+ def guild(self) -> MockGuild:
+ """Cached guild property."""
+ return MockGuild()
+
# Create data for the DMChannel instance
state = unittest.mock.MagicMock()
@@ -500,10 +510,12 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
super().__init__(**kwargs)
self.me = kwargs.get("me", MockMember())
self.bot = kwargs.get("bot", MockBot())
- self.guild = kwargs.get("guild", MockGuild())
- self.author = kwargs.get("author", MockMember())
- self.channel = kwargs.get("channel", MockTextChannel())
- self.message = kwargs.get("message", MockMessage())
+
+ self.message = kwargs.get("message", MockMessage(guild=self.guild))
+ self.author = kwargs.get("author", self.message.author)
+ self.channel = kwargs.get("channel", self.message.channel)
+ self.guild = kwargs.get("guild", self.channel.guild)
+
self.invoked_from_error_handler = kwargs.get("invoked_from_error_handler", False)
@@ -519,10 +531,12 @@ class MockInteraction(CustomMockMixin, unittest.mock.MagicMock):
super().__init__(**kwargs)
self.me = kwargs.get("me", MockMember())
self.client = kwargs.get("client", MockBot())
- self.guild = kwargs.get("guild", MockGuild())
- self.user = kwargs.get("user", MockMember())
- self.channel = kwargs.get("channel", MockTextChannel())
- self.message = kwargs.get("message", MockMessage())
+
+ self.message = kwargs.get("message", MockMessage(guild=self.guild))
+ self.user = kwargs.get("user", self.message.author)
+ self.channel = kwargs.get("channel", self.message.channel)
+ self.guild = kwargs.get("guild", self.channel.guild)
+
self.invoked_from_error_handler = kwargs.get("invoked_from_error_handler", False)