diff options
author | 2023-04-09 00:34:55 +0300 | |
---|---|---|
committer | 2023-04-09 00:34:55 +0300 | |
commit | c74a3c645d23d22206182813478f6c9b74812ed8 (patch) | |
tree | 62b648f92c2607a5ed94f90a86b09698d16f8c98 /tests | |
parent | Fix: use nomination.user_id instead of id in get_or_fetch_member (diff) | |
parent | Merge pull request #2517 from python-discord/thread_filter (diff) |
Merge branch 'main' into 2302-activity-in-reviews
Diffstat (limited to 'tests')
24 files changed, 712 insertions, 1507 deletions
diff --git a/tests/_autospec.py b/tests/_autospec.py index ee2fc1973..ecff6bcbe 100644 --- a/tests/_autospec.py +++ b/tests/_autospec.py @@ -1,5 +1,6 @@ import contextlib import functools +import pkgutil import unittest.mock from typing import Callable @@ -51,7 +52,7 @@ def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) # Import the target if it's a string. # This is to support both object and string targets like patch.multiple. if type(target) is str: - target = unittest.mock._importer(target) + target = pkgutil.resolve_name(target) def decorator(func): for attribute in attributes: diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filtering/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/exts/filters/__init__.py +++ b/tests/bot/exts/filtering/__init__.py diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py new file mode 100644 index 000000000..a5cddf8d9 --- /dev/null +++ b/tests/bot/exts/filtering/test_discord_token_filter.py @@ -0,0 +1,276 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock, patch + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.unique import discord_token +from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter, Token +from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel, autospec + + +class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase): + """Tests the DiscordTokenFilter class.""" + + def setUp(self): + """Adds the filter, a bot, and a message to the instance for usage in tests.""" + now = arrow.utcnow().timestamp() + self.filter = DiscordTokenFilter({ + "id": 1, + "content": "discord_token", + "description": None, + "settings": {}, + "additional_settings": {}, + "created_at": now, + "updated_at": now + }) + + self.msg = MockMessage(id=555, content="hello world") + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + + member = MockMember(id=123) + channel = MockTextChannel(id=345) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.msg) + + def test_extract_user_id_valid(self): + """Should consider user IDs valid if they decode into an integer ID.""" + id_pairs = ( + ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), + ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), + ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), + ) + + for token_id, user_id in id_pairs: + with self.subTest(token_id=token_id): + result = DiscordTokenFilter.extract_user_id(token_id) + self.assertEqual(result, user_id) + + def test_extract_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for user_id, msg in ids: + with self.subTest(msg=msg): + result = DiscordTokenFilter.extract_user_id(user_id) + self.assertIsNone(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = DiscordTokenFilter.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for timestamp, msg in timestamps: + with self.subTest(msg=msg): + result = DiscordTokenFilter.is_valid_timestamp(timestamp) + self.assertFalse(result) + + def test_is_valid_hmac_valid(self): + """Should consider an HMAC valid if it has at least 3 unique characters.""" + valid_hmacs = ( + "VXmErH7j511turNpfURmb0rVNm8", + "Ysnu2wacjaKs7qnoo46S8Dm2us8", + "sJf6omBPORBPju3WJEIAcwW9Zds", + "s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for hmac in valid_hmacs: + with self.subTest(msg=hmac): + result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) + self.assertTrue(result) + + def test_is_invalid_hmac_invalid(self): + """Should consider an HMAC invalid if has fewer than 3 unique characters.""" + invalid_hmacs = ( + ("xxxxxxxxxxxxxxxxxx", "Single character"), + ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), + ("ASFasfASFasfASFASsf", "Three characters alternating-case"), + ("asdasdasdasdasdasdasd", "Three characters one case"), + ) + + for hmac, msg in invalid_hmacs: + with self.subTest(msg=msg): + result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) + self.assertFalse(result) + + async def test_no_trigger_when_no_token(self): + """False should be returned if the message doesn't contain a Discord token.""" + return_value = await self.filter.triggered_on(self.ctx) + + self.assertFalse(return_value) + + @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") + @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") + @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") + def test_find_token_valid_match( + self, + token_re, + token_cls, + extract_user_id, + is_valid_timestamp, + is_maybe_valid_hmac, + ): + """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True + is_maybe_valid_hmac.return_value = True + + return_value = DiscordTokenFilter.find_token_in_message(self.msg) + + self.assertEqual(tokens[1], return_value) + + @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") + @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") + @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") + def test_find_token_invalid_matches( + self, + token_re, + token_cls, + extract_user_id, + is_valid_timestamp, + is_maybe_valid_hmac, + ): + """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + extract_user_id.return_value = None + is_valid_timestamp.return_value = False + is_maybe_valid_hmac.return_value = False + + return_value = DiscordTokenFilter.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", + ) + + for token in tokens: + with self.subTest(token=token): + results = discord_token.TOKEN_RE.findall(token) + self.assertEqual(len(results), 0) + + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, these tokens have been invalidated. + tokens = ( + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for token in tokens: + with self.subTest(token=token): + results = discord_token.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" + + results = discord_token.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual((token_1, token_2), results) + + @autospec("bot.exts.filtering._filters.unique.discord_token", "LOG_MESSAGE") + def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + log_message.format.return_value = "Howdy" + + return_value = DiscordTokenFilter.format_log_message(self.msg.author, self.msg.channel, token) + + self.assertEqual(return_value, log_message.format.return_value) + + @patch("bot.instance", MockBot()) + @autospec("bot.exts.filtering._filters.unique.discord_token", "UNKNOWN_USER_LOG_MESSAGE") + @autospec("bot.exts.filtering._filters.unique.discord_token", "get_or_fetch_member") + async def test_format_userid_log_message_unknown(self, get_or_fetch_member, unknown_user_log_message): + """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" + token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + unknown_user_log_message.format.return_value = " Partner" + get_or_fetch_member.return_value = None + + return_value = await DiscordTokenFilter.format_userid_log_message(token) + + self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) + + @patch("bot.instance", MockBot()) + @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") + async def test_format_userid_log_message_bot(self, known_user_log_message): + """Should correctly format the user ID portion when the ID belongs to a known bot.""" + token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + known_user_log_message.format.return_value = " Partner" + + return_value = await DiscordTokenFilter.format_userid_log_message(token) + + self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) + + @patch("bot.instance", MockBot()) + @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") + async def test_format_log_message_user_token_user(self, user_token_message): + """Should correctly format the user ID portion when the ID belongs to a known user.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + user_token_message.format.return_value = "Partner" + + return_value = await DiscordTokenFilter.format_userid_log_message(token) + + self.assertEqual(return_value, (user_token_message.format.return_value, True)) diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py new file mode 100644 index 000000000..827d267d2 --- /dev/null +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -0,0 +1,139 @@ +import unittest +from unittest.mock import MagicMock, patch + +import arrow + +from bot.constants import Channels +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import extension +from bot.exts.filtering._filter_lists.extension import ExtensionsList +from bot.exts.filtering._filter_lists.filter_list import ListType +from tests.helpers import MockAttachment, MockBot, MockMember, MockMessage, MockTextChannel + +BOT = MockBot() + + +class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): + """Test the ExtensionsList class.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.filter_list = ExtensionsList(MagicMock()) + now = arrow.utcnow().timestamp() + filters = [] + self.whitelist = [".first", ".second", ".third"] + for i, filter_content in enumerate(self.whitelist, start=1): + filters.append({ + "id": i, "content": filter_content, "description": None, "settings": {}, + "additional_settings": {}, "created_at": now, "updated_at": now # noqa: P103 + }) + self.filter_list.add_list({ + "id": 1, + "list_type": 1, + "created_at": now, + "updated_at": now, + "settings": {}, + "filters": filters + }) + + self.message = MockMessage() + member = MockMember(id=123) + channel = MockTextChannel(id=345) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.message) + + @patch("bot.instance", BOT) + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" + attachment = MockAttachment(filename="python.first") + ctx = self.ctx.replace(attachments=[attachment]) + + result = await self.filter_list.actions_for(ctx) + + self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) + + @patch("bot.instance", BOT) + async def test_message_without_attachment(self): + """Messages without attachments should return no triggers, messages, or actions.""" + result = await self.filter_list.actions_for(self.ctx) + + self.assertEqual(result, (None, [], {})) + + @patch("bot.instance", BOT) + async def test_message_with_illegal_extension(self): + """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" + attachment = MockAttachment(filename="python.disallowed") + ctx = self.ctx.replace(attachments=[attachment]) + + result = await self.filter_list.actions_for(ctx) + + self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) + + @patch("bot.instance", BOT) + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site.""" + attachment = MockAttachment(filename="python.py") + ctx = self.ctx.replace(attachments=[attachment]) + + await self.filter_list.actions_for(ctx) + + self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + + @patch("bot.instance", BOT) + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt/.json/.csv file should result in the correct embed.""" + test_values = ( + ("text", ".txt"), + ("json", ".json"), + ("csv", ".csv"), + ) + + for file_name, disallowed_extension in test_values: + with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): + + attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") + ctx = self.ctx.replace(attachments=[attachment]) + + await self.filter_list.actions_for(ctx) + + self.assertEqual( + ctx.dm_embed, + extension.TXT_EMBED_DESCRIPTION.format( + blocked_extension=disallowed_extension, + ) + ) + + @patch("bot.instance", BOT) + async def test_other_disallowed_extension_embed_description(self): + """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + ctx = self.ctx.replace(attachments=[attachment]) + + await self.filter_list.actions_for(ctx) + meta_channel = BOT.get_channel(Channels.meta) + + self.assertEqual( + ctx.dm_embed, + extension.DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=", ".join(self.whitelist), + joined_blacklist=".disallowed", + meta_channel_mention=meta_channel.mention + ) + ) + + @patch("bot.instance", BOT) + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (self.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], ["`.disallowed`"]), + ([".disallowed"], ["`.disallowed`"]), + ([".disallowed", ".illegal"], ["`.disallowed`", "`.illegal`"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions]) + result = await self.filter_list.actions_for(ctx) + self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py new file mode 100644 index 000000000..5a289c1cf --- /dev/null +++ b/tests/bot/exts/filtering/test_settings.py @@ -0,0 +1,20 @@ +import unittest + +import bot.exts.filtering._settings +from bot.exts.filtering._settings import create_settings + + +class FilterTests(unittest.TestCase): + """Test functionality of the Settings class and its subclasses.""" + + def test_create_settings_returns_none_for_empty_data(self): + """`create_settings` should return a tuple of two Nones when passed an empty dict.""" + result = create_settings({}) + + self.assertEqual(result, (None, None)) + + def test_unrecognized_entry_makes_a_warning(self): + """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`.""" + create_settings({"abcd": {}}) + + self.assertIn("abcd", bot.exts.filtering._settings._already_warned) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py new file mode 100644 index 000000000..3ae0b5ab5 --- /dev/null +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -0,0 +1,218 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( + Infraction, InfractionAndNotification, InfractionDuration +) +from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass +from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope +from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM +from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel + + +class FilterTests(unittest.TestCase): + """Test functionality of the Settings class and its subclasses.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + def test_role_bypass_is_off_for_user_without_roles(self): + """The role bypass should trigger when a user has no roles.""" + member = MockMember() + self.ctx.author = member + bypass_entry = RoleBypass(bypass_roles=["123"]) + + result = bypass_entry.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_role_bypass_is_on_for_a_user_with_the_right_role(self): + """The role bypass should not trigger when the user has one of its roles.""" + cases = ( + ([123], ["123"]), + ([123, 234], ["123"]), + ([123], ["123", "234"]), + ([123, 234], ["123", "234"]) + ) + + for user_role_ids, bypasses in cases: + with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses): + user_roles = [MockRole(id=role_id) for role_id in user_role_ids] + member = MockMember(roles=user_roles) + self.ctx.author = member + bypass_entry = RoleBypass(bypass_roles=bypasses) + + result = bypass_entry.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_for_empty_channel_scope(self): + """A filter is enabled for all channels by default.""" + channel = MockTextChannel() + scope = ChannelScope( + disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=None + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_context_doesnt_trigger_for_disabled_channel(self): + """A filter shouldn't trigger if it's been disabled in the channel.""" + channel = MockTextChannel(id=123) + scope = ChannelScope( + disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=None + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_in_disabled_category(self): + """A filter shouldn't trigger if it's been disabled in the category.""" + channel = MockTextChannel(category=MockCategoryChannel(id=456)) + scope = ChannelScope( + disabled_channels=None, disabled_categories=["456"], enabled_channels=None, enabled_categories=None + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_triggers_in_enabled_channel_in_disabled_category(self): + """A filter should trigger in an enabled channel even if it's been disabled in the category.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"], enabled_categories=None + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_context_triggers_inside_enabled_category(self): + """A filter shouldn't trigger outside enabled categories, if there are any.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["234"] + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_context_doesnt_trigger_outside_enabled_category(self): + """A filter shouldn't trigger outside enabled categories, if there are any.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["789"] + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_inside_disabled_channel_in_enabled_category(self): + """A filter shouldn't trigger outside enabled categories, if there are any.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=["234"] + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_filtering_dms_when_necessary(self): + """A filter correctly ignores or triggers in a channel depending on the value of FilterDM.""" + cases = ( + (True, MockDMChannel(), True), + (False, MockDMChannel(), False), + (True, MockTextChannel(), True), + (False, MockTextChannel(), True) + ) + + for apply_in_dms, channel, expected in cases: + with self.subTest(apply_in_dms=apply_in_dms, channel=channel): + filter_dms = FilterDM(filter_dm=apply_in_dms) + self.ctx.channel = channel + + result = filter_dms.triggers_on(self.ctx) + + self.assertEqual(expected, result) + + def test_infraction_merge_of_same_infraction_type(self): + """When both infractions are of the same type, the one with the longer duration wins.""" + infraction1 = InfractionAndNotification( + infraction_type="TIMEOUT", + infraction_reason="hi", + infraction_duration=InfractionDuration(10), + dm_content="how", + dm_embed="what is", + infraction_channel=0 + ) + infraction2 = InfractionAndNotification( + infraction_type="TIMEOUT", + infraction_reason="there", + infraction_duration=InfractionDuration(20), + dm_content="are you", + dm_embed="your name", + infraction_channel=0 + ) + + result = infraction1.union(infraction2) + + self.assertDictEqual( + result.dict(), + { + "infraction_type": Infraction.TIMEOUT, + "infraction_reason": "there", + "infraction_duration": InfractionDuration(20.0), + "dm_content": "are you", + "dm_embed": "your name", + "infraction_channel": 0 + } + ) + + def test_infraction_merge_of_different_infraction_types(self): + """If there are two different infraction types, the one higher up the hierarchy should be picked.""" + infraction1 = InfractionAndNotification( + infraction_type="TIMEOUT", + infraction_reason="hi", + infraction_duration=InfractionDuration(20), + dm_content="", + dm_embed="", + infraction_channel=0 + ) + infraction2 = InfractionAndNotification( + infraction_type="BAN", + infraction_reason="", + infraction_duration=InfractionDuration(10), + dm_content="there", + dm_embed="", + infraction_channel=0 + ) + + result = infraction1.union(infraction2) + + self.assertDictEqual( + result.dict(), + { + "infraction_type": Infraction.BAN, + "infraction_reason": "", + "infraction_duration": InfractionDuration(10), + "dm_content": "there", + "dm_embed": "", + "infraction_channel": 0 + } + ) diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py new file mode 100644 index 000000000..03fa6b4b9 --- /dev/null +++ b/tests/bot/exts/filtering/test_token_filter.py @@ -0,0 +1,49 @@ +import unittest + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.token import TokenFilter +from tests.helpers import MockMember, MockMessage, MockTextChannel + + +class TokenFilterTests(unittest.IsolatedAsyncioTestCase): + """Test functionality of the token filter.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + async def test_token_filter_triggers(self): + """The filter should evaluate to True only if its token is found in the context content.""" + test_cases = ( + (r"hi", "oh hi there", True), + (r"hi", "goodbye", False), + (r"bla\d{2,4}", "bla18", True), + (r"bla\d{2,4}", "bla1", False), + # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 + (r"TOKEN", "https://google.com TOKEN", True), + (r"TOKEN", "https://google.com something else", False) + ) + now = arrow.utcnow().timestamp() + + for pattern, content, expected in test_cases: + with self.subTest( + pattern=pattern, + content=content, + expected=expected, + ): + filter_ = TokenFilter({ + "id": 1, + "content": pattern, + "description": None, + "settings": {}, + "additional_settings": {}, + "created_at": now, + "updated_at": now + }) + self.ctx.content = content + result = await filter_.triggered_on(self.ctx) + self.assertEqual(result, expected) diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py deleted file mode 100644 index 7282334e2..000000000 --- a/tests/bot/exts/filters/test_antimalware.py +++ /dev/null @@ -1,202 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.constants import Channels, STAFF_ROLES -from bot.exts.filters import antimalware -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): - """Test the AntiMalware cog.""" - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = MockBot() - self.bot.filter_list_cache = { - "FILE_FORMAT.True": { - ".first": {}, - ".second": {}, - ".third": {}, - } - } - self.cog = antimalware.AntiMalware(self.bot) - self.message = MockMessage() - self.message.webhook_id = None - self.message.author.bot = None - self.whitelist = [".first", ".second", ".third"] - - async def test_message_with_allowed_attachment(self): - """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_message_without_attachment(self): - """Messages without attachments should result in no action.""" - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_direct_message_with_attachment(self): - """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.guild = None - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_webhook_message_with_illegal_extension(self): - """A webhook message containing an illegal extension should be ignored.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.webhook_id = 697140105563078727 - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_bot_message_with_illegal_extension(self): - """A bot message containing an illegal extension should be ignored.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.author.bot = 409107086526644234 - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_message_with_illegal_extension_gets_deleted(self): - """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_called_once() - - async def test_message_send_by_staff(self): - """A message send by a member of staff should be ignored.""" - staff_role = MockRole(id=STAFF_ROLES[0]) - self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_python_file_redirect_embed_description(self): - """A message containing a .py file should result in an embed redirecting the user to our paste site""" - attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - - self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - - async def test_txt_file_redirect_embed_description(self): - """A message containing a .txt/.json/.csv file should result in the correct embed.""" - test_values = ( - ("text", ".txt"), - ("json", ".json"), - ("csv", ".csv"), - ) - - for file_name, disallowed_extension in test_values: - with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): - - attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.TXT_EMBED_DESCRIPTION = Mock() - antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - cmd_channel = self.bot.get_channel(Channels.bot_commands) - - self.assertEqual( - embed.description, - antimalware.TXT_EMBED_DESCRIPTION.format.return_value - ) - antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with( - blocked_extension=disallowed_extension, - cmd_channel_mention=cmd_channel.mention - ) - - async def test_other_disallowed_extension_embed_description(self): - """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - meta_channel = self.bot.get_channel(Channels.meta) - - self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( - joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", - meta_channel_mention=meta_channel.mention - ) - - async def test_removing_deleted_message_logs(self): - """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - self.message.delete.assert_called_once() - - async def test_message_with_illegal_attachment_logs(self): - """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - - async def test_get_disallowed_extensions(self): - """The return value should include all non-whitelisted extensions.""" - test_values = ( - ([], []), - (self.whitelist, []), - ([".first"], []), - ([".first", ".disallowed"], [".disallowed"]), - ([".disallowed"], [".disallowed"]), - ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), - ) - - for extensions, expected_disallowed_extensions in test_values: - with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog._get_disallowed_extensions(self.message) - self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): - """Tests setup of the `AntiMalware` cog.""" - - async def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - await antimalware.setup(bot) - bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py deleted file mode 100644 index 6a0e4fded..000000000 --- a/tests/bot/exts/filters/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.exts.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): - """Tests validation of the antispam cog configuration.""" - - def test_default_antispam_config_is_valid(self): - """The default antispam configuration is valid.""" - validation_errors = antispam.validate_config() - self.assertEqual(validation_errors, {}) - - def test_unknown_rule_returns_error(self): - """Configuring an unknown rule returns an error.""" - self.assertEqual( - antispam.validate_config({'invalid-rule': {}}), - {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} - ) - - def test_missing_keys_returns_error(self): - """Not configuring required keys returns an error.""" - keys = (('interval', 'max'), ('max', 'interval')) - for configured_key, unconfigured_key in keys: - with self.subTest( - configured_key=configured_key, - unconfigured_key=unconfigured_key - ): - config = {'burst': {configured_key: 10}} - error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - - self.assertEqual( - antispam.validate_config(config), - {'burst': error} - ) diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py deleted file mode 100644 index e47cf627b..000000000 --- a/tests/bot/exts/filters/test_filtering.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot.exts.filters import filtering -from tests.helpers import MockBot, autospec - - -class FilteringCogTests(unittest.IsolatedAsyncioTestCase): - """Tests the `Filtering` cog.""" - - def setUp(self): - """Instantiate the bot and cog.""" - self.bot = MockBot() - with patch("pydis_core.utils.scheduling.create_task", new=lambda task, **_: task.close()): - self.cog = filtering.Filtering(self.bot) - - @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) - async def test_token_filter(self): - """Ensure that a filter token is correctly detected in a message.""" - messages = { - "": False, - "no matches": False, - "TOKEN": True, - - # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 - "https://google.com TOKEN": True, - "https://google.com something else": False, - } - - for message, match in messages.items(): - with self.subTest(input=message, match=match): - result, _ = await self.cog._has_watch_regex_match(message) - - self.assertEqual( - match, - bool(result), - msg=f"Hit was {'expected' if match else 'not expected'} for this input." - ) - if result: - self.assertEqual("TOKEN", result.group()) diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py deleted file mode 100644 index c1f3762ac..000000000 --- a/tests/bot/exts/filters/test_token_remover.py +++ /dev/null @@ -1,409 +0,0 @@ -import unittest -from re import Match -from unittest import mock -from unittest.mock import MagicMock - -from discord import Colour, NotFound - -from bot import constants -from bot.exts.filters import token_remover -from bot.exts.filters.token_remover import Token, TokenRemover -from bot.exts.moderation.modlog import ModLog -from bot.utils.messages import format_user -from tests.helpers import MockBot, MockMessage, autospec - - -class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): - """Tests the `TokenRemover` cog.""" - - def setUp(self): - """Adds the cog, a bot, and a message to the instance for usage in tests.""" - self.bot = MockBot() - self.cog = TokenRemover(bot=self.bot) - - self.msg = MockMessage(id=555, content="hello world") - self.msg.channel.mention = "#lemonade-stand" - self.msg.guild.get_member.return_value.bot = False - self.msg.guild.get_member.return_value.__str__.return_value = "Woody" - self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) - self.msg.author.display_avatar.url = "picture-lemon.png" - - def test_extract_user_id_valid(self): - """Should consider user IDs valid if they decode into an integer ID.""" - id_pairs = ( - ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), - ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), - ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), - ) - - for token_id, user_id in id_pairs: - with self.subTest(token_id=token_id): - result = TokenRemover.extract_user_id(token_id) - self.assertEqual(result, user_id) - - def test_extract_user_id_invalid(self): - """Should consider non-digit and non-ASCII IDs invalid.""" - ids = ( - ("SGVsbG8gd29ybGQ", "non-digit ASCII"), - ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), - ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), - ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), - ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for user_id, msg in ids: - with self.subTest(msg=msg): - result = TokenRemover.extract_user_id(user_id) - self.assertIsNone(result) - - def test_is_valid_timestamp_valid(self): - """Should consider timestamps valid if they're greater than the Discord epoch.""" - timestamps = ( - "XsyRkw", - "Xrim9Q", - "XsyR-w", - "XsySD_", - "Dn9r_A", - ) - - for timestamp in timestamps: - with self.subTest(timestamp=timestamp): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertTrue(result) - - def test_is_valid_timestamp_invalid(self): - """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" - timestamps = ( - ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), - ("ew", "123"), - ("AoIKgA", "42076800"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for timestamp, msg in timestamps: - with self.subTest(msg=msg): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertFalse(result) - - def test_is_valid_hmac_valid(self): - """Should consider an HMAC valid if it has at least 3 unique characters.""" - valid_hmacs = ( - "VXmErH7j511turNpfURmb0rVNm8", - "Ysnu2wacjaKs7qnoo46S8Dm2us8", - "sJf6omBPORBPju3WJEIAcwW9Zds", - "s45jqDV_Iisn-symw0yDRrk_jf4", - ) - - for hmac in valid_hmacs: - with self.subTest(msg=hmac): - result = TokenRemover.is_maybe_valid_hmac(hmac) - self.assertTrue(result) - - def test_is_invalid_hmac_invalid(self): - """Should consider an HMAC invalid if has fewer than 3 unique characters.""" - invalid_hmacs = ( - ("xxxxxxxxxxxxxxxxxx", "Single character"), - ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), - ("ASFasfASFasfASFASsf", "Three characters alternating-case"), - ("asdasdasdasdasdasdasd", "Three characters one case"), - ) - - for hmac, msg in invalid_hmacs: - with self.subTest(msg=msg): - result = TokenRemover.is_maybe_valid_hmac(hmac) - self.assertFalse(result) - - def test_mod_log_property(self): - """The `mod_log` property should ask the bot to return the `ModLog` cog.""" - self.bot.get_cog.return_value = 'lemon' - self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) - self.bot.get_cog.assert_called_once_with('ModLog') - - async def test_on_message_edit_uses_on_message(self): - """The edit listener should delegate handling of the message to the normal listener.""" - self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) - - await self.cog.on_message_edit(MockMessage(), self.msg) - self.cog.on_message.assert_awaited_once_with(self.msg) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_takes_action(self, find_token_in_message, take_action): - """Should take action if a valid token is found when a message is sent.""" - cog = TokenRemover(self.bot) - found_token = "foobar" - find_token_in_message.return_value = found_token - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_awaited_once_with(cog, self.msg, found_token) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): - """Shouldn't take action if a valid token isn't found when a message is sent.""" - cog = TokenRemover(self.bot) - find_token_in_message.return_value = False - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_not_awaited() - - @autospec(TokenRemover, "find_token_in_message") - async def test_on_message_ignores_dms_bots(self, find_token_in_message): - """Shouldn't parse a message if it is a DM or authored by a bot.""" - cog = TokenRemover(self.bot) - dm_msg = MockMessage(guild=None) - bot_msg = MockMessage(author=MagicMock(bot=True)) - - for msg in (dm_msg, bot_msg): - await cog.on_message(msg) - find_token_in_message.assert_not_called() - - @autospec("bot.exts.filters.token_remover", "TOKEN_RE") - def test_find_token_no_matches(self, token_re): - """None should be returned if the regex matches no tokens in a message.""" - token_re.finditer.return_value = () - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") - @autospec("bot.exts.filters.token_remover", "Token") - @autospec("bot.exts.filters.token_remover", "TOKEN_RE") - def test_find_token_valid_match( - self, - token_re, - token_cls, - extract_user_id, - is_valid_timestamp, - is_maybe_valid_hmac, - ): - """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" - matches = [ - mock.create_autospec(Match, spec_set=True, instance=True), - mock.create_autospec(Match, spec_set=True, instance=True), - ] - tokens = [ - mock.create_autospec(Token, spec_set=True, instance=True), - mock.create_autospec(Token, spec_set=True, instance=True), - ] - - token_re.finditer.return_value = matches - token_cls.side_effect = tokens - extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid. - is_valid_timestamp.return_value = True - is_maybe_valid_hmac.return_value = True - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertEqual(tokens[1], return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") - @autospec("bot.exts.filters.token_remover", "Token") - @autospec("bot.exts.filters.token_remover", "TOKEN_RE") - def test_find_token_invalid_matches( - self, - token_re, - token_cls, - extract_user_id, - is_valid_timestamp, - is_maybe_valid_hmac, - ): - """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" - token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] - token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) - extract_user_id.return_value = None - is_valid_timestamp.return_value = False - is_maybe_valid_hmac.return_value = False - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - def test_regex_invalid_tokens(self): - """Messages without anything looking like a token are not matched.""" - tokens = ( - "", - "lemon wins", - "..", - "x.y", - "x.y.", - ".y.z", - ".y.", - "..z", - "x..z", - " . . ", - "\n.\n.\n", - "hellö.world.bye", - "base64.nötbåse64.morebase64", - "19jd3J.dfkm3d.€víł§tüff", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.findall(token) - self.assertEqual(len(results), 0) - - def test_regex_valid_tokens(self): - """Messages that look like tokens should be matched.""" - # Don't worry, these tokens have been invalidated. - tokens = ( - "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", - "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", - "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", - "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.fullmatch(token) - self.assertIsNotNone(results, f"{token} was not matched by the regex") - - def test_regex_matches_multiple_valid(self): - """Should support multiple matches in the middle of a string.""" - token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" - token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" - message = f"garbage {token_1} hello {token_2} world" - - results = token_remover.TOKEN_RE.finditer(message) - results = [match[0] for match in results] - self.assertCountEqual((token_1, token_2), results) - - @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") - def test_format_log_message(self, log_message): - """Should correctly format the log message with info from the message and token.""" - token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - log_message.format.return_value = "Howdy" - - return_value = TokenRemover.format_log_message(self.msg, token) - - self.assertEqual(return_value, log_message.format.return_value) - log_message.format.assert_called_once_with( - author=format_user(self.msg.author), - channel=self.msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac="xxxxxxxxxxxxxxxxxxxxxxxxjf4", - ) - - @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE") - async def test_format_userid_log_message_unknown(self, unknown_user_log_message,): - """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" - token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - unknown_user_log_message.format.return_value = " Partner" - msg = MockMessage(id=555, content="hello world") - msg.guild.get_member.return_value = None - msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found") - - return_value = await TokenRemover.format_userid_log_message(msg, token) - - self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) - unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332) - - @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") - async def test_format_userid_log_message_bot(self, known_user_log_message): - """Should correctly format the user ID portion when the ID belongs to a known bot.""" - token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - known_user_log_message.format.return_value = " Partner" - msg = MockMessage(id=555, content="hello world") - msg.guild.get_member.return_value.__str__.return_value = "Sam" - msg.guild.get_member.return_value.bot = True - - return_value = await TokenRemover.format_userid_log_message(msg, token) - - self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) - - known_user_log_message.format.assert_called_once_with( - user_id=472265943062413332, - user_name="Sam", - kind="BOT", - ) - - @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") - async def test_format_log_message_user_token_user(self, user_token_message): - """Should correctly format the user ID portion when the ID belongs to a known user.""" - token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - user_token_message.format.return_value = "Partner" - - return_value = await TokenRemover.format_userid_log_message(self.msg, token) - - self.assertEqual(return_value, (user_token_message.format.return_value, True)) - user_token_message.format.assert_called_once_with( - user_id=467223230650777641, - user_name="Woody", - kind="USER", - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - @autospec("bot.exts.filters.token_remover", "log") - @autospec(TokenRemover, "format_log_message", "format_userid_log_message") - async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property): - """Should delete the message and send a mod log.""" - cog = TokenRemover(self.bot) - mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) - token = mock.create_autospec(Token, spec_set=True, instance=True) - token.user_id = "no-id" - log_msg = "testing123" - userid_log_message = "userid-log-message" - - mod_log_property.return_value = mod_log - format_log_message.return_value = log_msg - format_userid_log_message.return_value = (userid_log_message, True) - - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) - ) - - format_log_message.assert_called_once_with(self.msg, token) - format_userid_log_message.assert_called_once_with(self.msg, token) - logger.debug.assert_called_with(log_msg) - self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - - mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) - mod_log.send_log_message.assert_called_once_with( - icon_url=constants.Icons.token_removed, - colour=Colour(constants.Colours.soft_red), - title="Token removed!", - text=log_msg + "\n" + userid_log_message, - thumbnail=self.msg.author.display_avatar.url, - channel_id=constants.Channels.mod_alerts, - ping_everyone=True, - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - async def test_take_action_delete_failure(self, mod_log_property): - """Shouldn't send any messages if the token message can't be deleted.""" - cog = TokenRemover(self.bot) - mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) - self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) - - token = mock.create_autospec(Token, spec_set=True, instance=True) - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_not_awaited() - - -class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): - """Tests for the token_remover extension.""" - - @autospec("bot.exts.filters.token_remover", "TokenRemover") - async def test_extension_setup(self, cog): - """The TokenRemover cog should be added.""" - bot = MockBot() - await token_remover.setup(bot) - - cog.assert_called_once_with(bot) - bot.add_cog.assert_awaited_once() - self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9dcf7fd8c..79ac8ea2c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # Should not be called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code('MyAwesomeCode') @@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py deleted file mode 100644 index 0d570f5a3..000000000 --- a/tests/bot/rules/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest -from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple - -from tests.helpers import MockMessage - - -class DisallowedCase(NamedTuple): - """Encapsulation for test cases expected to fail.""" - recent_messages: List[MockMessage] - culprits: Iterable[str] - n_violations: int - - -class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): - """ - Abstract class for antispam rule test cases. - - Tests for specific rules should inherit from `RuleTest` and implement - `relevant_messages` and `get_report`. Each instance should also set the - `apply` and `config` attributes as necessary. - - The execution of test cases can then be delegated to the `run_allowed` - and `run_disallowed` methods. - """ - - apply: Callable # The tested rule's apply function - config: Dict[str, int] - - async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: - """Run all `cases` against `self.apply` expecting them to pass.""" - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config, - ): - self.assertIsNone( - await self.apply(last_message, recent_messages, self.config) - ) - - async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: - """Run all `cases` against `self.apply` expecting them to fail.""" - for case in cases: - recent_messages, culprits, n_violations = case - last_message = recent_messages[0] - relevant_messages = self.relevant_messages(case) - desired_output = ( - self.get_report(case), - culprits, - relevant_messages, - ) - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - n_violations=n_violations, - config=self.config, - ): - self.assertTupleEqual( - await self.apply(last_message, recent_messages, self.config), - desired_output, - ) - - @abstractmethod - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - """Give expected relevant messages for `case`.""" - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_report(self, case: DisallowedCase) -> str: - """Give expected error report for `case`.""" - raise NotImplementedError # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py deleted file mode 100644 index d7e779221..000000000 --- a/tests/bot/rules/test_attachments.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Iterable - -from bot.rules import attachments -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_attachments: int) -> MockMessage: - """Builds a message with `total_attachments` attachments.""" - return MockMessage(author=author, attachments=list(range(total_attachments))) - - -class AttachmentRuleTests(RuleTest): - """Tests applying the `attachments` antispam rule.""" - - def setUp(self): - self.apply = attachments.apply - self.config = {"max": 5, "interval": 10} - - async def test_allows_messages_without_too_many_attachments(self): - """Messages without too many attachments are allowed as-is.""" - cases = ( - [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], - [make_msg("bob", 2), make_msg("bob", 2)], - [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_with_too_many_attachments(self): - """Messages with too many attachments trigger the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], - ("bob",), - 10, - ), - DisallowedCase( - [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], - ("bob",), - 6, - ), - DisallowedCase( - [make_msg("alice", 6)], - ("alice",), - 6, - ), - DisallowedCase( - [make_msg("alice", 1) for _ in range(6)], - ("alice",), - 6, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if ( - msg.author == last_message.author - and len(msg.attachments) > 0 - ) - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py deleted file mode 100644 index 03682966b..000000000 --- a/tests/bot/rules/test_burst.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Iterable - -from bot.rules import burst -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: - """ - Init a MockMessage instance with author set to `author`. - - This serves as a shorthand / alias to keep the test cases visually clean. - """ - return MockMessage(author=author) - - -class BurstRuleTests(RuleTest): - """Tests the `burst` antispam rule.""" - - def setUp(self): - self.apply = burst.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases which do not violate the rule.""" - cases = ( - [make_msg("bob"), make_msg("bob")], - [make_msg("bob"), make_msg("alice"), make_msg("bob")], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases where the amount of messages exceeds the limit, triggering the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("bob")], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], - ("bob",), - 3, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py deleted file mode 100644 index 3275143d5..000000000 --- a/tests/bot/rules/test_burst_shared.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Iterable - -from bot.rules import burst_shared -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: - """ - Init a MockMessage instance with the passed arg. - - This serves as a shorthand / alias to keep the test cases visually clean. - """ - return MockMessage(author=author) - - -class BurstSharedRuleTests(RuleTest): - """Tests the `burst_shared` antispam rule.""" - - def setUp(self): - self.apply = burst_shared.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """ - Cases that do not violate the rule. - - There really isn't more to test here than a single case. - """ - cases = ( - [make_msg("spongebob"), make_msg("patrick")], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases where the amount of messages exceeds the limit, triggering the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("bob")], - {"bob"}, - 3, - ), - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], - {"bob", "alice"}, - 4, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - return case.recent_messages - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py deleted file mode 100644 index f1e3c76a7..000000000 --- a/tests/bot/rules/test_chars.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import chars -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_chars: int) -> MockMessage: - """Build a message with arbitrary content of `n_chars` length.""" - return MockMessage(author=author, content="A" * n_chars) - - -class CharsRuleTests(RuleTest): - """Tests the `chars` antispam rule.""" - - def setUp(self): - self.apply = chars.apply - self.config = { - "max": 20, # Max allowed sum of chars per user - "interval": 10, - } - - async def test_allows_messages_within_limit(self): - """Cases with a total amount of chars within limit.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 20)], - [make_msg("bob", 15), make_msg("alice", 15)], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases where the total amount of chars exceeds the limit, triggering the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob", 21)], - ("bob",), - 21, - ), - DisallowedCase( - [make_msg("bob", 15), make_msg("bob", 15)], - ("bob",), - 30, - ), - DisallowedCase( - [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], - ("alice",), - 30, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py deleted file mode 100644 index 66c2d9f92..000000000 --- a/tests/bot/rules/test_discord_emojis.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Iterable - -from bot.rules import discord_emojis -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - -discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> -unicode_emoji = "🧪" - - -def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage: - """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" - return MockMessage(author=author, content=emoji * n_emojis) - - -class DiscordEmojisRuleTests(RuleTest): - """Tests for the `discord_emojis` antispam rule.""" - - def setUp(self): - self.apply = discord_emojis.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases with a total amount of discord and unicode emojis within limit.""" - cases = ( - [make_msg("bob", 2)], - [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], - [make_msg("bob", 2, unicode_emoji)], - [ - make_msg("alice", 1, unicode_emoji), - make_msg("bob", 2, unicode_emoji), - make_msg("alice", 1, unicode_emoji) - ], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases with more than the allowed amount of discord and unicode emojis.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], - ("alice",), - 4, - ), - DisallowedCase( - [make_msg("bob", 3, unicode_emoji)], - ("bob",), - 3, - ), - DisallowedCase( - [ - make_msg("alice", 2, unicode_emoji), - make_msg("bob", 2, unicode_emoji), - make_msg("alice", 2, unicode_emoji) - ], - ("alice",), - 4 - ) - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py deleted file mode 100644 index 9bd886a77..000000000 --- a/tests/bot/rules/test_duplicates.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import duplicates -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, content: str) -> MockMessage: - """Give a MockMessage instance with `author` and `content` attrs.""" - return MockMessage(author=author, content=content) - - -class DuplicatesRuleTests(RuleTest): - """Tests the `duplicates` antispam rule.""" - - def setUp(self): - self.apply = duplicates.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases which do not violate the rule.""" - cases = ( - [make_msg("alice", "A"), make_msg("alice", "A")], - [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate - [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases with too many duplicate messages from the same author.""" - cases = ( - DisallowedCase( - [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], - ("alice",), - 3, - ), - DisallowedCase( - [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], - ("bob",), - 3, # 4 duplicate messages, but only 3 from bob - ), - DisallowedCase( - [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], - ("bob",), - 3, # 4 message from bob, but only 3 duplicates - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if ( - msg.author == last_message.author - and msg.content == last_message.content - ) - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py deleted file mode 100644 index b091bd9d7..000000000 --- a/tests/bot/rules/test_links.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Iterable - -from bot.rules import links -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_links: int) -> MockMessage: - """Makes a message with `total_links` links.""" - content = " ".join(["https://pydis.com"] * total_links) - return MockMessage(author=author, content=content) - - -class LinksTests(RuleTest): - """Tests applying the `links` rule.""" - - def setUp(self): - self.apply = links.apply - self.config = { - "max": 2, - "interval": 10 - } - - async def test_links_within_limit(self): - """Messages with an allowed amount of links.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 2)], - [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 - [make_msg("bob", 1), make_msg("bob", 1)], - [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count - ) - - await self.run_allowed(cases) - - async def test_links_exceeding_limit(self): - """Messages with a a higher than allowed amount of links.""" - cases = ( - DisallowedCase( - [make_msg("bob", 1), make_msg("bob", 2)], - ("bob",), - 3 - ), - DisallowedCase( - [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], - ("alice",), - 3 - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], - ("alice",), - 3 - ) - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py deleted file mode 100644 index e1f904917..000000000 --- a/tests/bot/rules/test_mentions.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Iterable, Optional - -import discord - -from bot.rules import mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage, MockMessageReference - - -def make_msg( - author: str, - total_user_mentions: int, - total_bot_mentions: int = 0, - *, - reference: Optional[MockMessageReference] = None -) -> MockMessage: - """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" - user_mentions = [MockMember() for _ in range(total_user_mentions)] - bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - - mentions = user_mentions + bot_mentions - if reference is not None: - # For the sake of these tests we assume that all references are mentions. - mentions.append(reference.resolved.author) - msg_type = discord.MessageType.reply - else: - msg_type = discord.MessageType.default - - return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) - - -class TestMentions(RuleTest): - """Tests applying the `mentions` antispam rule.""" - - def setUp(self): - self.apply = mentions.apply - self.config = { - "max": 2, - "interval": 10, - } - - async def test_mentions_within_limit(self): - """Messages with an allowed amount of mentions.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 2)], - [make_msg("bob", 1), make_msg("bob", 1)], - [make_msg("bob", 1), make_msg("alice", 2)], - ) - - await self.run_allowed(cases) - - async def test_mentions_exceeding_limit(self): - """Messages with a higher than allowed amount of mentions.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], - ("alice",), - 3, - ), - DisallowedCase( - [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], - ("bob",), - 4, - ), - DisallowedCase( - [make_msg("bob", 3, 1)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("bob", 3, reference=MockMessageReference())], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], - ("bob",), - 3 - ) - ) - - await self.run_disallowed(cases) - - async def test_ignore_bot_mentions(self): - """Messages with an allowed amount of mentions, also containing bot mentions.""" - cases = ( - [make_msg("bob", 0, 3)], - [make_msg("bob", 2, 1)], - [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], - [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] - ) - - await self.run_allowed(cases) - - async def test_ignore_reply_mentions(self): - """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" - cases = ( - [ - make_msg("bob", 2, reference=MockMessageReference()) - ], - [ - make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) - ], - [ - make_msg("bob", 2, reference=MockMessageReference()), - make_msg("bob", 0, reference=MockMessageReference()) - ], - [ - make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), - make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) - ] - ) - - await self.run_allowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py deleted file mode 100644 index e35377773..000000000 --- a/tests/bot/rules/test_newlines.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Iterable, List - -from bot.rules import newlines -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, newline_groups: List[int]) -> MockMessage: - """Init a MockMessage instance with `author` and content configured by `newline_groups". - - Configure content by passing a list of ints, where each int `n` will generate - a separate group of `n` newlines. - - Example: - newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" - """ - content = " ".join("\n" * n for n in newline_groups) - return MockMessage(author=author, content=content) - - -class TotalNewlinesRuleTests(RuleTest): - """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" - - def setUp(self): - self.apply = newlines.apply - self.config = { - "max": 5, # Max sum of newlines in relevant messages - "max_consecutive": 3, # Max newlines in one group, in one message - "interval": 10, - } - - async def test_allows_messages_within_limit(self): - """Cases which do not violate the rule.""" - cases = ( - [make_msg("alice", [])], # Single message with no newlines - [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages - [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author - [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_total(self): - """Cases which violate the rule by having too many newlines in total.""" - cases = ( - DisallowedCase( # Alice sends a total of 6 newlines (disallowed) - [make_msg("alice", [2, 2]), make_msg("alice", [2])], - ("alice",), - 6, - ), - DisallowedCase( # Here we test that only alice's newlines count in the sum - [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], - ("alice",), - 7, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_author = case.recent_messages[0].author - return tuple(msg for msg in case.recent_messages if msg.author == last_author) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} newlines in {self.config['interval']}s" - - -class GroupNewlinesRuleTests(RuleTest): - """ - Tests the `newlines` antispam rule against max consecutive newline violations. - - As these violations yield a different error report, they require a different - `get_report` implementation. - """ - - def setUp(self): - self.apply = newlines.apply - self.config = {"max": 5, "max_consecutive": 3, "interval": 10} - - async def test_disallows_messages_consecutive(self): - """Cases which violate the rule due to having too many consecutive newlines.""" - cases = ( - DisallowedCase( # Bob sends a group of newlines too large - [make_msg("bob", [4])], - ("bob",), - 4, - ), - DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed) - [make_msg("alice", [1]), make_msg("alice", [4])], - ("alice",), - 4, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_author = case.recent_messages[0].author - return tuple(msg for msg in case.recent_messages if msg.author == last_author) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py deleted file mode 100644 index 26c05d527..000000000 --- a/tests/bot/rules/test_role_mentions.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Iterable - -from bot.rules import role_mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_mentions: int) -> MockMessage: - """Build a MockMessage instance with `n_mentions` role mentions.""" - return MockMessage(author=author, role_mentions=[None] * n_mentions) - - -class RoleMentionsRuleTests(RuleTest): - """Tests for the `role_mentions` antispam rule.""" - - def setUp(self): - self.apply = role_mentions.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases with a total amount of role mentions within limit.""" - cases = ( - [make_msg("bob", 2)], - [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases with more than the allowed amount of role mentions.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], - ("alice",), - 4, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/helpers.py b/tests/helpers.py index 1a71f210a..020f1aee5 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ - A MagicMock subclass to mock TextChannel objects. + A MagicMock subclass to mock DMChannel objects. - Instances of this class will follow the specifications of `discord.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.DMChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = dm_channel_instance def __init__(self, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} + default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) @@ -423,7 +423,7 @@ category_channel_instance = discord.CategoryChannel( class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id)} - super().__init__(**collections.ChainMap(default_kwargs, kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) # Create a Message instance to get a realistic MagicMock of `discord.Message` |