aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar Matteo Bertucci <[email protected]>2021-08-24 14:09:47 +0200
committerGravatar GitHub <[email protected]>2021-08-24 14:09:47 +0200
commit6098a47056d4e5698adb47e87e373c0f3e19fe3a (patch)
treef11ac27f95db06d9638a5df8a1fe4e560504159a /tests
parentEnable debug mode by default (diff)
parentMerge pull request #1775 from python-discord/TizzySaurus-patch-1 (diff)
Merge branch 'main' into enhance/1683/restrict-int-eval
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/exts/events/__init__.py0
-rw-r--r--tests/bot/exts/events/test_code_jams.py (renamed from tests/bot/exts/utils/test_jams.py)66
-rw-r--r--tests/bot/exts/moderation/infraction/test_infractions.py6
-rw-r--r--tests/bot/exts/moderation/infraction/test_utils.py4
-rw-r--r--tests/bot/rules/test_mentions.py26
-rw-r--r--tests/bot/utils/test_message_cache.py214
6 files changed, 272 insertions, 44 deletions
diff --git a/tests/bot/exts/events/__init__.py b/tests/bot/exts/events/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/exts/events/__init__.py
diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/events/test_code_jams.py
index 368a15476..b9ee1e363 100644
--- a/tests/bot/exts/utils/test_jams.py
+++ b/tests/bot/exts/events/test_code_jams.py
@@ -1,14 +1,15 @@
import unittest
-from unittest.mock import AsyncMock, MagicMock, create_autospec
+from unittest.mock import AsyncMock, MagicMock, create_autospec, patch
from discord import CategoryChannel
from discord.ext.commands import BadArgument
from bot.constants import Roles
-from bot.exts.utils import jams
+from bot.exts.events import code_jams
+from bot.exts.events.code_jams import _channels, _cog
from tests.helpers import (
MockAttachment, MockBot, MockCategoryChannel, MockContext,
- MockGuild, MockMember, MockRole, MockTextChannel
+ MockGuild, MockMember, MockRole, MockTextChannel, autospec
)
TEST_CSV = b"""\
@@ -40,7 +41,7 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
self.command_user = MockMember([self.admin_role])
self.guild = MockGuild([self.admin_role])
self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild)
- self.cog = jams.CodeJams(self.bot)
+ self.cog = _cog.CodeJams(self.bot)
async def test_message_without_attachments(self):
"""If no link or attachments are provided, commands.BadArgument should be raised."""
@@ -49,7 +50,9 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(BadArgument):
await self.cog.create(self.cog, self.ctx, None)
- async def test_result_sending(self):
+ @patch.object(_channels, "create_team_channel")
+ @patch.object(_channels, "create_team_leader_channel")
+ async def test_result_sending(self, create_leader_channel, create_team_channel):
"""Should call `ctx.send` when everything goes right."""
self.ctx.message.attachments = [MockAttachment()]
self.ctx.message.attachments[0].read = AsyncMock()
@@ -61,14 +64,12 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
self.ctx.guild.create_role = AsyncMock()
self.ctx.guild.create_role.return_value = team_leaders
- self.cog.create_team_channel = AsyncMock()
- self.cog.create_team_leader_channel = AsyncMock()
self.cog.add_roles = AsyncMock()
await self.cog.create(self.cog, self.ctx, None)
- self.cog.create_team_channel.assert_awaited()
- self.cog.create_team_leader_channel.assert_awaited_once_with(
+ create_team_channel.assert_awaited()
+ create_leader_channel.assert_awaited_once_with(
self.ctx.guild, team_leaders
)
self.ctx.send.assert_awaited_once()
@@ -81,25 +82,24 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
self.ctx.send.assert_awaited_once()
- async def test_category_doesnt_exist(self):
+ @patch.object(_channels, "_send_status_update")
+ async def test_category_doesnt_exist(self, update):
"""Should create a new code jam category."""
subtests = (
[],
- [get_mock_category(jams.MAX_CHANNELS, jams.CATEGORY_NAME)],
- [get_mock_category(jams.MAX_CHANNELS - 2, "other")],
+ [get_mock_category(_channels.MAX_CHANNELS, _channels.CATEGORY_NAME)],
+ [get_mock_category(_channels.MAX_CHANNELS - 2, "other")],
)
- self.cog.send_status_update = AsyncMock()
-
for categories in subtests:
- self.cog.send_status_update.reset_mock()
+ update.reset_mock()
self.guild.reset_mock()
self.guild.categories = categories
with self.subTest(categories=categories):
- actual_category = await self.cog.get_category(self.guild)
+ actual_category = await _channels._get_category(self.guild)
- self.cog.send_status_update.assert_called_once()
+ update.assert_called_once()
self.guild.create_category_channel.assert_awaited_once()
category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"]
@@ -109,45 +109,41 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
async def test_category_channel_exist(self):
"""Should not try to create category channel."""
- expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME)
+ expected_category = get_mock_category(_channels.MAX_CHANNELS - 2, _channels.CATEGORY_NAME)
self.guild.categories = [
- get_mock_category(jams.MAX_CHANNELS - 2, "other"),
+ get_mock_category(_channels.MAX_CHANNELS - 2, "other"),
expected_category,
- get_mock_category(0, jams.CATEGORY_NAME),
+ get_mock_category(0, _channels.CATEGORY_NAME),
]
- actual_category = await self.cog.get_category(self.guild)
+ actual_category = await _channels._get_category(self.guild)
self.assertEqual(expected_category, actual_category)
async def test_channel_overwrites(self):
"""Should have correct permission overwrites for users and roles."""
leader = (MockMember(), True)
members = [leader] + [(MockMember(), False) for _ in range(4)]
- overwrites = self.cog.get_overwrites(members, self.guild)
+ overwrites = _channels._get_overwrites(members, self.guild)
for member, _ in members:
self.assertTrue(overwrites[member].read_messages)
- async def test_team_channels_creation(self):
+ @patch.object(_channels, "_get_overwrites")
+ @patch.object(_channels, "_get_category")
+ @autospec(_channels, "_add_team_leader_roles", pass_mocks=False)
+ async def test_team_channels_creation(self, get_category, get_overwrites):
"""Should create a text channel for a team."""
team_leaders = MockRole()
members = [(MockMember(), True)] + [(MockMember(), False) for _ in range(5)]
category = MockCategoryChannel()
category.create_text_channel = AsyncMock()
- self.cog.get_overwrites = MagicMock()
- self.cog.get_category = AsyncMock()
- self.cog.get_category.return_value = category
- self.cog.add_team_leader_roles = AsyncMock()
-
- await self.cog.create_team_channel(self.guild, "my-team", members, team_leaders)
- self.cog.add_team_leader_roles.assert_awaited_once_with(members, team_leaders)
- self.cog.get_overwrites.assert_called_once_with(members, self.guild)
- self.cog.get_category.assert_awaited_once_with(self.guild)
+ get_category.return_value = category
+ await _channels.create_team_channel(self.guild, "my-team", members, team_leaders)
category.create_text_channel.assert_awaited_once_with(
"my-team",
- overwrites=self.cog.get_overwrites.return_value
+ overwrites=get_overwrites.return_value
)
async def test_jam_roles_adding(self):
@@ -156,7 +152,7 @@ class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase):
leader = MockMember()
members = [(leader, True)] + [(MockMember(), False) for _ in range(4)]
- await self.cog.add_team_leader_roles(members, leader_role)
+ await _channels._add_team_leader_roles(members, leader_role)
leader.add_roles.assert_awaited_once_with(leader_role)
for member, is_leader in members:
@@ -170,5 +166,5 @@ class CodeJamSetup(unittest.TestCase):
def test_setup(self):
"""Should call `bot.add_cog`."""
bot = MockBot()
- jams.setup(bot)
+ code_jams.setup(bot)
bot.add_cog.assert_called_once()
diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py
index b9d527770..f844a9181 100644
--- a/tests/bot/exts/moderation/infraction/test_infractions.py
+++ b/tests/bot/exts/moderation/infraction/test_infractions.py
@@ -195,7 +195,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
async def test_voice_unban_user_not_found(self):
"""Should include info to return dict when user was not found from guild."""
self.guild.get_member.return_value = None
- result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild)
self.assertEqual(result, {"Info": "User was not found in the guild."})
@patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
@@ -206,7 +206,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
notify_pardon_mock.return_value = True
format_user_mock.return_value = "my-user"
- result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild)
self.assertEqual(result, {
"Member": "my-user",
"DM": "Sent"
@@ -221,7 +221,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
notify_pardon_mock.return_value = False
format_user_mock.return_value = "my-user"
- result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild)
self.assertEqual(result, {
"Member": "my-user",
"DM": "**Failed**"
diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py
index 5f95ced9f..eb256f1fd 100644
--- a/tests/bot/exts/moderation/infraction/test_utils.py
+++ b/tests/bot/exts/moderation/infraction/test_utils.py
@@ -94,8 +94,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):
test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"])
test_cases = [
test_case([], None, None, True),
- test_case([{"id": 123987}], {"id": 123987}, "123987", False),
- test_case([{"id": 123987}], {"id": 123987}, "123987", True)
+ test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False),
+ test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True)
]
for case in test_cases:
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index 6444532f2..f8805ac48 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -2,12 +2,14 @@ from typing import Iterable
from bot.rules import mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
+from tests.helpers import MockMember, MockMessage
-def make_msg(author: str, total_mentions: int) -> MockMessage:
+def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage:
"""Makes a message with `total_mentions` mentions."""
- return MockMessage(author=author, mentions=list(range(total_mentions)))
+ user_mentions = [MockMember() for _ in range(total_user_mentions)]
+ bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)]
+ return MockMessage(author=author, mentions=user_mentions+bot_mentions)
class TestMentions(RuleTest):
@@ -48,11 +50,27 @@ class TestMentions(RuleTest):
[make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
("bob",),
4,
- )
+ ),
+ DisallowedCase(
+ [make_msg("bob", 3, 1)],
+ ("bob",),
+ 3,
+ ),
)
await self.run_disallowed(cases)
+ async def test_ignore_bot_mentions(self):
+ """Messages with an allowed amount of mentions, also containing bot mentions."""
+ cases = (
+ [make_msg("bob", 0, 3)],
+ [make_msg("bob", 2, 1)],
+ [make_msg("bob", 1, 2), make_msg("bob", 1, 2)],
+ [make_msg("bob", 1, 5), make_msg("alice", 2, 5)]
+ )
+
+ await self.run_allowed(cases)
+
def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
last_message = case.recent_messages[0]
return tuple(
diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py
new file mode 100644
index 000000000..04bfd28d1
--- /dev/null
+++ b/tests/bot/utils/test_message_cache.py
@@ -0,0 +1,214 @@
+import unittest
+
+from bot.utils.message_cache import MessageCache
+from tests.helpers import MockMessage
+
+
+# noinspection SpellCheckingInspection
+class TestMessageCache(unittest.TestCase):
+ """Tests for the MessageCache class in the `bot.utils.caching` module."""
+
+ def test_first_append_sets_the_first_value(self):
+ """Test if the first append adds the message to the first cell."""
+ cache = MessageCache(maxlen=10)
+ message = MockMessage()
+
+ cache.append(message)
+
+ self.assertEqual(cache[0], message)
+
+ def test_append_adds_in_the_right_order(self):
+ """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise."""
+ messages = [MockMessage(), MockMessage()]
+
+ cache = MessageCache(maxlen=10, newest_first=False)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages, list(cache))
+
+ cache = MessageCache(maxlen=10, newest_first=True)
+ for msg in messages:
+ cache.append(msg)
+ self.assertListEqual(messages[::-1], list(cache))
+
+ def test_appending_over_maxlen_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages."""
+ cache = MessageCache(maxlen=2)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_appending_over_maxlen_with_newest_first_removes_oldest(self):
+ """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True."""
+ cache = MessageCache(maxlen=2, newest_first=True)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ self.assertListEqual(messages[:0:-1], list(cache))
+
+ def test_pop_removes_from_the_end(self):
+ """Test if a pop removes the right-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.pop()
+
+ self.assertEqual(msg, messages[-1])
+ self.assertListEqual(messages[:-1], list(cache))
+
+ def test_popleft_removes_from_the_beginning(self):
+ """Test if a popleft removes the left-most message."""
+ cache = MessageCache(maxlen=3)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ msg = cache.popleft()
+
+ self.assertEqual(msg, messages[0])
+ self.assertListEqual(messages[1:], list(cache))
+
+ def test_clear(self):
+ """Test if a clear makes the cache empty."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+
+ for msg in messages:
+ cache.append(msg)
+ cache.clear()
+
+ self.assertListEqual(list(cache), [])
+ self.assertEqual(len(cache), 0)
+
+ def test_get_message_returns_the_message(self):
+ """Test if get_message returns the cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertEqual(cache.get_message(1234), message)
+
+ def test_get_message_returns_none(self):
+ """Test if get_message returns None for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIsNone(cache.get_message(4321))
+
+ def test_update_replaces_old_element(self):
+ """Test if an update replaced the old message with the same ID."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+ message = MockMessage(id=1234)
+ cache.update(message)
+
+ self.assertIs(cache.get_message(1234), message)
+ self.assertEqual(len(cache), 1)
+
+ def test_contains_returns_true_for_cached_message(self):
+ """Test if contains returns True for an ID of a cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertIn(1234, cache)
+
+ def test_contains_returns_false_for_non_cached_message(self):
+ """Test if contains returns False for an ID of a non-cached message."""
+ cache = MessageCache(maxlen=5)
+ message = MockMessage(id=1234)
+
+ cache.append(message)
+
+ self.assertNotIn(4321, cache)
+
+ def test_indexing(self):
+ """Test if the cache returns the correct messages by index."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(5)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in range(-5, 5):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(cache[current_loop], messages[current_loop])
+
+ def test_bad_index_raises_index_error(self):
+ """Test if the cache raises IndexError for invalid indices."""
+ cache = MessageCache(maxlen=5)
+ messages = [MockMessage() for _ in range(3)]
+ test_cases = (-10, -4, 3, 4, 5)
+
+ for msg in messages:
+ cache.append(msg)
+
+ for current_loop in test_cases:
+ with self.subTest(current_loop=current_loop):
+ with self.assertRaises(IndexError):
+ cache[current_loop]
+
+ def test_slicing_with_unfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache is not yet fully filled."""
+ sizes = (5, 10, 55, 101)
+
+ slices = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1)
+ )
+
+ for size in sizes:
+ cache = MessageCache(maxlen=size)
+ messages = [MockMessage() for _ in range(size // 3 * 2)]
+
+ for msg in messages:
+ cache.append(msg)
+
+ for slice_ in slices:
+ with self.subTest(current_loop=(size, slice_)):
+ self.assertListEqual(cache[slice_], messages[slice_])
+
+ def test_slicing_with_overfilled_cache(self):
+ """Test if slicing returns the correct messages if the cache was appended with more messages it can contain."""
+ sizes = (5, 10, 55, 101)
+
+ slices = (
+ slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2),
+ slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2),
+ slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1)
+ )
+
+ for size in sizes:
+ cache = MessageCache(maxlen=size)
+ messages = [MockMessage() for _ in range(size * 3 // 2)]
+
+ for msg in messages:
+ cache.append(msg)
+ messages = messages[size // 2:]
+
+ for slice_ in slices:
+ with self.subTest(current_loop=(size, slice_)):
+ self.assertListEqual(cache[slice_], messages[slice_])
+
+ def test_length(self):
+ """Test if len returns the correct number of items in the cache."""
+ cache = MessageCache(maxlen=5)
+
+ for current_loop in range(10):
+ with self.subTest(current_loop=current_loop):
+ self.assertEqual(len(cache), min(current_loop, 5))
+ cache.append(MockMessage())