diff options
Diffstat (limited to 'tests')
24 files changed, 712 insertions, 1507 deletions
| diff --git a/tests/_autospec.py b/tests/_autospec.py index ee2fc1973..ecff6bcbe 100644 --- a/tests/_autospec.py +++ b/tests/_autospec.py @@ -1,5 +1,6 @@  import contextlib  import functools +import pkgutil  import unittest.mock  from typing import Callable @@ -51,7 +52,7 @@ def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs)      # Import the target if it's a string.      # This is to support both object and string targets like patch.multiple.      if type(target) is str: -        target = unittest.mock._importer(target) +        target = pkgutil.resolve_name(target)      def decorator(func):          for attribute in attributes: diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filtering/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/bot/exts/filters/__init__.py +++ b/tests/bot/exts/filtering/__init__.py diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py new file mode 100644 index 000000000..a5cddf8d9 --- /dev/null +++ b/tests/bot/exts/filtering/test_discord_token_filter.py @@ -0,0 +1,276 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock, patch + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.unique import discord_token +from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter, Token +from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel, autospec + + +class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase): +    """Tests the DiscordTokenFilter class.""" + +    def setUp(self): +        """Adds the filter, a bot, and a message to the instance for usage in tests.""" +        now = arrow.utcnow().timestamp() +        self.filter = DiscordTokenFilter({ +            "id": 1, +            "content": "discord_token", +            "description": None, +            "settings": {}, +            "additional_settings": {}, +            "created_at": now, +            "updated_at": now +        }) + +        self.msg = MockMessage(id=555, content="hello world") +        self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.msg) + +    def test_extract_user_id_valid(self): +        """Should consider user IDs valid if they decode into an integer ID.""" +        id_pairs = ( +            ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), +            ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), +            ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), +        ) + +        for token_id, user_id in id_pairs: +            with self.subTest(token_id=token_id): +                result = DiscordTokenFilter.extract_user_id(token_id) +                self.assertEqual(result, user_id) + +    def test_extract_user_id_invalid(self): +        """Should consider non-digit and non-ASCII IDs invalid.""" +        ids = ( +            ("SGVsbG8gd29ybGQ", "non-digit ASCII"), +            ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), +            ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), +            ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), +            ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for user_id, msg in ids: +            with self.subTest(msg=msg): +                result = DiscordTokenFilter.extract_user_id(user_id) +                self.assertIsNone(result) + +    def test_is_valid_timestamp_valid(self): +        """Should consider timestamps valid if they're greater than the Discord epoch.""" +        timestamps = ( +            "XsyRkw", +            "Xrim9Q", +            "XsyR-w", +            "XsySD_", +            "Dn9r_A", +        ) + +        for timestamp in timestamps: +            with self.subTest(timestamp=timestamp): +                result = DiscordTokenFilter.is_valid_timestamp(timestamp) +                self.assertTrue(result) + +    def test_is_valid_timestamp_invalid(self): +        """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" +        timestamps = ( +            ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), +            ("ew", "123"), +            ("AoIKgA", "42076800"), +            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), +            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), +        ) + +        for timestamp, msg in timestamps: +            with self.subTest(msg=msg): +                result = DiscordTokenFilter.is_valid_timestamp(timestamp) +                self.assertFalse(result) + +    def test_is_valid_hmac_valid(self): +        """Should consider an HMAC valid if it has at least 3 unique characters.""" +        valid_hmacs = ( +            "VXmErH7j511turNpfURmb0rVNm8", +            "Ysnu2wacjaKs7qnoo46S8Dm2us8", +            "sJf6omBPORBPju3WJEIAcwW9Zds", +            "s45jqDV_Iisn-symw0yDRrk_jf4", +        ) + +        for hmac in valid_hmacs: +            with self.subTest(msg=hmac): +                result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) +                self.assertTrue(result) + +    def test_is_invalid_hmac_invalid(self): +        """Should consider an HMAC invalid if has fewer than 3 unique characters.""" +        invalid_hmacs = ( +            ("xxxxxxxxxxxxxxxxxx", "Single character"), +            ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), +            ("ASFasfASFasfASFASsf", "Three characters alternating-case"), +            ("asdasdasdasdasdasdasd", "Three characters one case"), +        ) + +        for hmac, msg in invalid_hmacs: +            with self.subTest(msg=msg): +                result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) +                self.assertFalse(result) + +    async def test_no_trigger_when_no_token(self): +        """False should be returned if the message doesn't contain a Discord token.""" +        return_value = await self.filter.triggered_on(self.ctx) + +        self.assertFalse(return_value) + +    @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") +    def test_find_token_valid_match( +        self, +        token_re, +        token_cls, +        extract_user_id, +        is_valid_timestamp, +        is_maybe_valid_hmac, +    ): +        """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" +        matches = [ +            mock.create_autospec(Match, spec_set=True, instance=True), +            mock.create_autospec(Match, spec_set=True, instance=True), +        ] +        tokens = [ +            mock.create_autospec(Token, spec_set=True, instance=True), +            mock.create_autospec(Token, spec_set=True, instance=True), +        ] + +        token_re.finditer.return_value = matches +        token_cls.side_effect = tokens +        extract_user_id.side_effect = (None, True)  # The 1st match will be invalid, 2nd one valid. +        is_valid_timestamp.return_value = True +        is_maybe_valid_hmac.return_value = True + +        return_value = DiscordTokenFilter.find_token_in_message(self.msg) + +        self.assertEqual(tokens[1], return_value) + +    @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") +    def test_find_token_invalid_matches( +        self, +        token_re, +        token_cls, +        extract_user_id, +        is_valid_timestamp, +        is_maybe_valid_hmac, +    ): +        """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" +        token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] +        token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) +        extract_user_id.return_value = None +        is_valid_timestamp.return_value = False +        is_maybe_valid_hmac.return_value = False + +        return_value = DiscordTokenFilter.find_token_in_message(self.msg) + +        self.assertIsNone(return_value) + +    def test_regex_invalid_tokens(self): +        """Messages without anything looking like a token are not matched.""" +        tokens = ( +            "", +            "lemon wins", +            "..", +            "x.y", +            "x.y.", +            ".y.z", +            ".y.", +            "..z", +            "x..z", +            " . . ", +            "\n.\n.\n", +            "hellö.world.bye", +            "base64.nötbåse64.morebase64", +            "19jd3J.dfkm3d.€víł§tüff", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = discord_token.TOKEN_RE.findall(token) +                self.assertEqual(len(results), 0) + +    def test_regex_valid_tokens(self): +        """Messages that look like tokens should be matched.""" +        # Don't worry, these tokens have been invalidated. +        tokens = ( +            "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", +            "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", +            "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", +            "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", +        ) + +        for token in tokens: +            with self.subTest(token=token): +                results = discord_token.TOKEN_RE.fullmatch(token) +                self.assertIsNotNone(results, f"{token} was not matched by the regex") + +    def test_regex_matches_multiple_valid(self): +        """Should support multiple matches in the middle of a string.""" +        token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" +        token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" +        message = f"garbage {token_1} hello {token_2} world" + +        results = discord_token.TOKEN_RE.finditer(message) +        results = [match[0] for match in results] +        self.assertCountEqual((token_1, token_2), results) + +    @autospec("bot.exts.filtering._filters.unique.discord_token", "LOG_MESSAGE") +    def test_format_log_message(self, log_message): +        """Should correctly format the log message with info from the message and token.""" +        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        log_message.format.return_value = "Howdy" + +        return_value = DiscordTokenFilter.format_log_message(self.msg.author, self.msg.channel, token) + +        self.assertEqual(return_value, log_message.format.return_value) + +    @patch("bot.instance", MockBot()) +    @autospec("bot.exts.filtering._filters.unique.discord_token", "UNKNOWN_USER_LOG_MESSAGE") +    @autospec("bot.exts.filtering._filters.unique.discord_token", "get_or_fetch_member") +    async def test_format_userid_log_message_unknown(self, get_or_fetch_member, unknown_user_log_message): +        """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" +        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        unknown_user_log_message.format.return_value = " Partner" +        get_or_fetch_member.return_value = None + +        return_value = await DiscordTokenFilter.format_userid_log_message(token) + +        self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) + +    @patch("bot.instance", MockBot()) +    @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") +    async def test_format_userid_log_message_bot(self, known_user_log_message): +        """Should correctly format the user ID portion when the ID belongs to a known bot.""" +        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        known_user_log_message.format.return_value = " Partner" + +        return_value = await DiscordTokenFilter.format_userid_log_message(token) + +        self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) + +    @patch("bot.instance", MockBot()) +    @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") +    async def test_format_log_message_user_token_user(self, user_token_message): +        """Should correctly format the user ID portion when the ID belongs to a known user.""" +        token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") +        user_token_message.format.return_value = "Partner" + +        return_value = await DiscordTokenFilter.format_userid_log_message(token) + +        self.assertEqual(return_value, (user_token_message.format.return_value, True)) diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py new file mode 100644 index 000000000..827d267d2 --- /dev/null +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -0,0 +1,139 @@ +import unittest +from unittest.mock import MagicMock, patch + +import arrow + +from bot.constants import Channels +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import extension +from bot.exts.filtering._filter_lists.extension import ExtensionsList +from bot.exts.filtering._filter_lists.filter_list import ListType +from tests.helpers import MockAttachment, MockBot, MockMember, MockMessage, MockTextChannel + +BOT = MockBot() + + +class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): +    """Test the ExtensionsList class.""" + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.filter_list = ExtensionsList(MagicMock()) +        now = arrow.utcnow().timestamp() +        filters = [] +        self.whitelist = [".first", ".second", ".third"] +        for i, filter_content in enumerate(self.whitelist, start=1): +            filters.append({ +                "id": i, "content": filter_content, "description": None, "settings": {}, +                "additional_settings": {}, "created_at": now, "updated_at": now  # noqa: P103 +            }) +        self.filter_list.add_list({ +            "id": 1, +            "list_type": 1, +            "created_at": now, +            "updated_at": now, +            "settings": {}, +            "filters": filters +        }) + +        self.message = MockMessage() +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.message) + +    @patch("bot.instance", BOT) +    async def test_message_with_allowed_attachment(self): +        """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" +        attachment = MockAttachment(filename="python.first") +        ctx = self.ctx.replace(attachments=[attachment]) + +        result = await self.filter_list.actions_for(ctx) + +        self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) + +    @patch("bot.instance", BOT) +    async def test_message_without_attachment(self): +        """Messages without attachments should return no triggers, messages, or actions.""" +        result = await self.filter_list.actions_for(self.ctx) + +        self.assertEqual(result, (None, [], {})) + +    @patch("bot.instance", BOT) +    async def test_message_with_illegal_extension(self): +        """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" +        attachment = MockAttachment(filename="python.disallowed") +        ctx = self.ctx.replace(attachments=[attachment]) + +        result = await self.filter_list.actions_for(ctx) + +        self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) + +    @patch("bot.instance", BOT) +    async def test_python_file_redirect_embed_description(self): +        """A message containing a .py file should result in an embed redirecting the user to our paste site.""" +        attachment = MockAttachment(filename="python.py") +        ctx = self.ctx.replace(attachments=[attachment]) + +        await self.filter_list.actions_for(ctx) + +        self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + +    @patch("bot.instance", BOT) +    async def test_txt_file_redirect_embed_description(self): +        """A message containing a .txt/.json/.csv file should result in the correct embed.""" +        test_values = ( +            ("text", ".txt"), +            ("json", ".json"), +            ("csv", ".csv"), +        ) + +        for file_name, disallowed_extension in test_values: +            with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): + +                attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") +                ctx = self.ctx.replace(attachments=[attachment]) + +                await self.filter_list.actions_for(ctx) + +                self.assertEqual( +                    ctx.dm_embed, +                    extension.TXT_EMBED_DESCRIPTION.format( +                        blocked_extension=disallowed_extension, +                    ) +                ) + +    @patch("bot.instance", BOT) +    async def test_other_disallowed_extension_embed_description(self): +        """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" +        attachment = MockAttachment(filename="python.disallowed") +        ctx = self.ctx.replace(attachments=[attachment]) + +        await self.filter_list.actions_for(ctx) +        meta_channel = BOT.get_channel(Channels.meta) + +        self.assertEqual( +            ctx.dm_embed, +            extension.DISALLOWED_EMBED_DESCRIPTION.format( +                joined_whitelist=", ".join(self.whitelist), +                joined_blacklist=".disallowed", +                meta_channel_mention=meta_channel.mention +            ) +        ) + +    @patch("bot.instance", BOT) +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            ([], []), +            (self.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], ["`.disallowed`"]), +            ([".disallowed"], ["`.disallowed`"]), +            ([".disallowed", ".illegal"], ["`.disallowed`", "`.illegal`"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions]) +                result = await self.filter_list.actions_for(ctx) +                self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py new file mode 100644 index 000000000..5a289c1cf --- /dev/null +++ b/tests/bot/exts/filtering/test_settings.py @@ -0,0 +1,20 @@ +import unittest + +import bot.exts.filtering._settings +from bot.exts.filtering._settings import create_settings + + +class FilterTests(unittest.TestCase): +    """Test functionality of the Settings class and its subclasses.""" + +    def test_create_settings_returns_none_for_empty_data(self): +        """`create_settings` should return a tuple of two Nones when passed an empty dict.""" +        result = create_settings({}) + +        self.assertEqual(result, (None, None)) + +    def test_unrecognized_entry_makes_a_warning(self): +        """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`.""" +        create_settings({"abcd": {}}) + +        self.assertIn("abcd", bot.exts.filtering._settings._already_warned) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py new file mode 100644 index 000000000..3ae0b5ab5 --- /dev/null +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -0,0 +1,218 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( +    Infraction, InfractionAndNotification, InfractionDuration +) +from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass +from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope +from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM +from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel + + +class FilterTests(unittest.TestCase): +    """Test functionality of the Settings class and its subclasses.""" + +    def setUp(self) -> None: +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        message = MockMessage(author=member, channel=channel) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + +    def test_role_bypass_is_off_for_user_without_roles(self): +        """The role bypass should trigger when a user has no roles.""" +        member = MockMember() +        self.ctx.author = member +        bypass_entry = RoleBypass(bypass_roles=["123"]) + +        result = bypass_entry.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_role_bypass_is_on_for_a_user_with_the_right_role(self): +        """The role bypass should not trigger when the user has one of its roles.""" +        cases = ( +            ([123], ["123"]), +            ([123, 234], ["123"]), +            ([123], ["123", "234"]), +            ([123, 234], ["123", "234"]) +        ) + +        for user_role_ids, bypasses in cases: +            with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses): +                user_roles = [MockRole(id=role_id) for role_id in user_role_ids] +                member = MockMember(roles=user_roles) +                self.ctx.author = member +                bypass_entry = RoleBypass(bypass_roles=bypasses) + +                result = bypass_entry.triggers_on(self.ctx) + +                self.assertFalse(result) + +    def test_context_doesnt_trigger_for_empty_channel_scope(self): +        """A filter is enabled for all channels by default.""" +        channel = MockTextChannel() +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_context_doesnt_trigger_for_disabled_channel(self): +        """A filter shouldn't trigger if it's been disabled in the channel.""" +        channel = MockTextChannel(id=123) +        scope = ChannelScope( +            disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_context_doesnt_trigger_in_disabled_category(self): +        """A filter shouldn't trigger if it's been disabled in the category.""" +        channel = MockTextChannel(category=MockCategoryChannel(id=456)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=["456"], enabled_channels=None, enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_context_triggers_in_enabled_channel_in_disabled_category(self): +        """A filter should trigger in an enabled channel even if it's been disabled in the category.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"], enabled_categories=None +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_context_triggers_inside_enabled_category(self): +        """A filter shouldn't trigger outside enabled categories, if there are any.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["234"] +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertTrue(result) + +    def test_context_doesnt_trigger_outside_enabled_category(self): +        """A filter shouldn't trigger outside enabled categories, if there are any.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["789"] +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_context_doesnt_trigger_inside_disabled_channel_in_enabled_category(self): +        """A filter shouldn't trigger outside enabled categories, if there are any.""" +        channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) +        scope = ChannelScope( +            disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=["234"] +        ) +        self.ctx.channel = channel + +        result = scope.triggers_on(self.ctx) + +        self.assertFalse(result) + +    def test_filtering_dms_when_necessary(self): +        """A filter correctly ignores or triggers in a channel depending on the value of FilterDM.""" +        cases = ( +            (True, MockDMChannel(), True), +            (False, MockDMChannel(), False), +            (True, MockTextChannel(), True), +            (False, MockTextChannel(), True) +        ) + +        for apply_in_dms, channel, expected in cases: +            with self.subTest(apply_in_dms=apply_in_dms, channel=channel): +                filter_dms = FilterDM(filter_dm=apply_in_dms) +                self.ctx.channel = channel + +                result = filter_dms.triggers_on(self.ctx) + +                self.assertEqual(expected, result) + +    def test_infraction_merge_of_same_infraction_type(self): +        """When both infractions are of the same type, the one with the longer duration wins.""" +        infraction1 = InfractionAndNotification( +            infraction_type="TIMEOUT", +            infraction_reason="hi", +            infraction_duration=InfractionDuration(10), +            dm_content="how", +            dm_embed="what is", +            infraction_channel=0 +        ) +        infraction2 = InfractionAndNotification( +            infraction_type="TIMEOUT", +            infraction_reason="there", +            infraction_duration=InfractionDuration(20), +            dm_content="are you", +            dm_embed="your name", +            infraction_channel=0 +        ) + +        result = infraction1.union(infraction2) + +        self.assertDictEqual( +            result.dict(), +            { +                "infraction_type": Infraction.TIMEOUT, +                "infraction_reason": "there", +                "infraction_duration": InfractionDuration(20.0), +                "dm_content": "are you", +                "dm_embed": "your name", +                "infraction_channel": 0 +            } +        ) + +    def test_infraction_merge_of_different_infraction_types(self): +        """If there are two different infraction types, the one higher up the hierarchy should be picked.""" +        infraction1 = InfractionAndNotification( +            infraction_type="TIMEOUT", +            infraction_reason="hi", +            infraction_duration=InfractionDuration(20), +            dm_content="", +            dm_embed="", +            infraction_channel=0 +        ) +        infraction2 = InfractionAndNotification( +            infraction_type="BAN", +            infraction_reason="", +            infraction_duration=InfractionDuration(10), +            dm_content="there", +            dm_embed="", +            infraction_channel=0 +        ) + +        result = infraction1.union(infraction2) + +        self.assertDictEqual( +            result.dict(), +            { +                "infraction_type": Infraction.BAN, +                "infraction_reason": "", +                "infraction_duration": InfractionDuration(10), +                "dm_content": "there", +                "dm_embed": "", +                "infraction_channel": 0 +            } +        ) diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py new file mode 100644 index 000000000..03fa6b4b9 --- /dev/null +++ b/tests/bot/exts/filtering/test_token_filter.py @@ -0,0 +1,49 @@ +import unittest + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.token import TokenFilter +from tests.helpers import MockMember, MockMessage, MockTextChannel + + +class TokenFilterTests(unittest.IsolatedAsyncioTestCase): +    """Test functionality of the token filter.""" + +    def setUp(self) -> None: +        member = MockMember(id=123) +        channel = MockTextChannel(id=345) +        message = MockMessage(author=member, channel=channel) +        self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + +    async def test_token_filter_triggers(self): +        """The filter should evaluate to True only if its token is found in the context content.""" +        test_cases = ( +            (r"hi", "oh hi there", True), +            (r"hi", "goodbye", False), +            (r"bla\d{2,4}", "bla18", True), +            (r"bla\d{2,4}", "bla1", False), +            # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 +            (r"TOKEN", "https://google.com TOKEN", True), +            (r"TOKEN", "https://google.com something else", False) +        ) +        now = arrow.utcnow().timestamp() + +        for pattern, content, expected in test_cases: +            with self.subTest( +                pattern=pattern, +                content=content, +                expected=expected, +            ): +                filter_ = TokenFilter({ +                    "id": 1, +                    "content": pattern, +                    "description": None, +                    "settings": {}, +                    "additional_settings": {}, +                    "created_at": now, +                    "updated_at": now +                }) +                self.ctx.content = content +                result = await filter_.triggered_on(self.ctx) +                self.assertEqual(result, expected) diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py deleted file mode 100644 index 7282334e2..000000000 --- a/tests/bot/exts/filters/test_antimalware.py +++ /dev/null @@ -1,202 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.constants import Channels, STAFF_ROLES -from bot.exts.filters import antimalware -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): -    """Test the AntiMalware cog.""" - -    def setUp(self): -        """Sets up fresh objects for each test.""" -        self.bot = MockBot() -        self.bot.filter_list_cache = { -            "FILE_FORMAT.True": { -                ".first": {}, -                ".second": {}, -                ".third": {}, -            } -        } -        self.cog = antimalware.AntiMalware(self.bot) -        self.message = MockMessage() -        self.message.webhook_id = None -        self.message.author.bot = None -        self.whitelist = [".first", ".second", ".third"] - -    async def test_message_with_allowed_attachment(self): -        """Messages with allowed extensions should not be deleted""" -        attachment = MockAttachment(filename="python.first") -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) -        self.message.delete.assert_not_called() - -    async def test_message_without_attachment(self): -        """Messages without attachments should result in no action.""" -        await self.cog.on_message(self.message) -        self.message.delete.assert_not_called() - -    async def test_direct_message_with_attachment(self): -        """Direct messages should have no action taken.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] -        self.message.guild = None - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_webhook_message_with_illegal_extension(self): -        """A webhook message containing an illegal extension should be ignored.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.webhook_id = 697140105563078727 -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_bot_message_with_illegal_extension(self): -        """A bot message containing an illegal extension should be ignored.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.author.bot = 409107086526644234 -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_message_with_illegal_extension_gets_deleted(self): -        """A message containing an illegal extension should send an embed.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_called_once() - -    async def test_message_send_by_staff(self): -        """A message send by a member of staff should be ignored.""" -        staff_role = MockRole(id=STAFF_ROLES[0]) -        self.message.author.roles.append(staff_role) -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] - -        await self.cog.on_message(self.message) - -        self.message.delete.assert_not_called() - -    async def test_python_file_redirect_embed_description(self): -        """A message containing a .py file should result in an embed redirecting the user to our paste site""" -        attachment = MockAttachment(filename="python.py") -        self.message.attachments = [attachment] -        self.message.channel.send = AsyncMock() - -        await self.cog.on_message(self.message) -        self.message.channel.send.assert_called_once() -        args, kwargs = self.message.channel.send.call_args -        embed = kwargs.pop("embed") - -        self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - -    async def test_txt_file_redirect_embed_description(self): -        """A message containing a .txt/.json/.csv file should result in the correct embed.""" -        test_values = ( -            ("text", ".txt"), -            ("json", ".json"), -            ("csv", ".csv"), -        ) - -        for file_name, disallowed_extension in test_values: -            with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): - -                attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") -                self.message.attachments = [attachment] -                self.message.channel.send = AsyncMock() -                antimalware.TXT_EMBED_DESCRIPTION = Mock() -                antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - -                await self.cog.on_message(self.message) -                self.message.channel.send.assert_called_once() -                args, kwargs = self.message.channel.send.call_args -                embed = kwargs.pop("embed") -                cmd_channel = self.bot.get_channel(Channels.bot_commands) - -                self.assertEqual( -                    embed.description, -                    antimalware.TXT_EMBED_DESCRIPTION.format.return_value -                ) -                antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with( -                    blocked_extension=disallowed_extension, -                    cmd_channel_mention=cmd_channel.mention -                ) - -    async def test_other_disallowed_extension_embed_description(self): -        """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] -        self.message.channel.send = AsyncMock() -        antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() -        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - -        await self.cog.on_message(self.message) -        self.message.channel.send.assert_called_once() -        args, kwargs = self.message.channel.send.call_args -        embed = kwargs.pop("embed") -        meta_channel = self.bot.get_channel(Channels.meta) - -        self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) -        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( -            joined_whitelist=", ".join(self.whitelist), -            blocked_extensions_str=".disallowed", -            meta_channel_mention=meta_channel.mention -        ) - -    async def test_removing_deleted_message_logs(self): -        """Removing an already deleted message logs the correct message""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] -        self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - -        with self.assertLogs(logger=antimalware.log, level="INFO"): -            await self.cog.on_message(self.message) -        self.message.delete.assert_called_once() - -    async def test_message_with_illegal_attachment_logs(self): -        """Deleting a message with an illegal attachment should result in a log.""" -        attachment = MockAttachment(filename="python.disallowed") -        self.message.attachments = [attachment] - -        with self.assertLogs(logger=antimalware.log, level="INFO"): -            await self.cog.on_message(self.message) - -    async def test_get_disallowed_extensions(self): -        """The return value should include all non-whitelisted extensions.""" -        test_values = ( -            ([], []), -            (self.whitelist, []), -            ([".first"], []), -            ([".first", ".disallowed"], [".disallowed"]), -            ([".disallowed"], [".disallowed"]), -            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), -        ) - -        for extensions, expected_disallowed_extensions in test_values: -            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): -                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] -                disallowed_extensions = self.cog._get_disallowed_extensions(self.message) -                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): -    """Tests setup of the `AntiMalware` cog.""" - -    async def test_setup(self): -        """Setup of the extension should call add_cog.""" -        bot = MockBot() -        await antimalware.setup(bot) -        bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py deleted file mode 100644 index 6a0e4fded..000000000 --- a/tests/bot/exts/filters/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.exts.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): -    """Tests validation of the antispam cog configuration.""" - -    def test_default_antispam_config_is_valid(self): -        """The default antispam configuration is valid.""" -        validation_errors = antispam.validate_config() -        self.assertEqual(validation_errors, {}) - -    def test_unknown_rule_returns_error(self): -        """Configuring an unknown rule returns an error.""" -        self.assertEqual( -            antispam.validate_config({'invalid-rule': {}}), -            {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} -        ) - -    def test_missing_keys_returns_error(self): -        """Not configuring required keys returns an error.""" -        keys = (('interval', 'max'), ('max', 'interval')) -        for configured_key, unconfigured_key in keys: -            with self.subTest( -                configured_key=configured_key, -                unconfigured_key=unconfigured_key -            ): -                config = {'burst': {configured_key: 10}} -                error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - -                self.assertEqual( -                    antispam.validate_config(config), -                    {'burst': error} -                ) diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py deleted file mode 100644 index e47cf627b..000000000 --- a/tests/bot/exts/filters/test_filtering.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot.exts.filters import filtering -from tests.helpers import MockBot, autospec - - -class FilteringCogTests(unittest.IsolatedAsyncioTestCase): -    """Tests the `Filtering` cog.""" - -    def setUp(self): -        """Instantiate the bot and cog.""" -        self.bot = MockBot() -        with patch("pydis_core.utils.scheduling.create_task", new=lambda task, **_: task.close()): -            self.cog = filtering.Filtering(self.bot) - -    @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) -    async def test_token_filter(self): -        """Ensure that a filter token is correctly detected in a message.""" -        messages = { -            "": False, -            "no matches": False, -            "TOKEN": True, - -            # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 -            "https://google.com TOKEN": True, -            "https://google.com something else": False, -        } - -        for message, match in messages.items(): -            with self.subTest(input=message, match=match): -                result, _ = await self.cog._has_watch_regex_match(message) - -                self.assertEqual( -                    match, -                    bool(result), -                    msg=f"Hit was {'expected' if match else 'not expected'} for this input." -                ) -                if result: -                    self.assertEqual("TOKEN", result.group()) diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py deleted file mode 100644 index c1f3762ac..000000000 --- a/tests/bot/exts/filters/test_token_remover.py +++ /dev/null @@ -1,409 +0,0 @@ -import unittest -from re import Match -from unittest import mock -from unittest.mock import MagicMock - -from discord import Colour, NotFound - -from bot import constants -from bot.exts.filters import token_remover -from bot.exts.filters.token_remover import Token, TokenRemover -from bot.exts.moderation.modlog import ModLog -from bot.utils.messages import format_user -from tests.helpers import MockBot, MockMessage, autospec - - -class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): -    """Tests the `TokenRemover` cog.""" - -    def setUp(self): -        """Adds the cog, a bot, and a message to the instance for usage in tests.""" -        self.bot = MockBot() -        self.cog = TokenRemover(bot=self.bot) - -        self.msg = MockMessage(id=555, content="hello world") -        self.msg.channel.mention = "#lemonade-stand" -        self.msg.guild.get_member.return_value.bot = False -        self.msg.guild.get_member.return_value.__str__.return_value = "Woody" -        self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) -        self.msg.author.display_avatar.url = "picture-lemon.png" - -    def test_extract_user_id_valid(self): -        """Should consider user IDs valid if they decode into an integer ID.""" -        id_pairs = ( -            ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), -            ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), -            ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), -        ) - -        for token_id, user_id in id_pairs: -            with self.subTest(token_id=token_id): -                result = TokenRemover.extract_user_id(token_id) -                self.assertEqual(result, user_id) - -    def test_extract_user_id_invalid(self): -        """Should consider non-digit and non-ASCII IDs invalid.""" -        ids = ( -            ("SGVsbG8gd29ybGQ", "non-digit ASCII"), -            ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), -            ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), -            ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), -            ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), -            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), -            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), -        ) - -        for user_id, msg in ids: -            with self.subTest(msg=msg): -                result = TokenRemover.extract_user_id(user_id) -                self.assertIsNone(result) - -    def test_is_valid_timestamp_valid(self): -        """Should consider timestamps valid if they're greater than the Discord epoch.""" -        timestamps = ( -            "XsyRkw", -            "Xrim9Q", -            "XsyR-w", -            "XsySD_", -            "Dn9r_A", -        ) - -        for timestamp in timestamps: -            with self.subTest(timestamp=timestamp): -                result = TokenRemover.is_valid_timestamp(timestamp) -                self.assertTrue(result) - -    def test_is_valid_timestamp_invalid(self): -        """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" -        timestamps = ( -            ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), -            ("ew", "123"), -            ("AoIKgA", "42076800"), -            ("{hello}[world]&(bye!)", "ASCII invalid Base64"), -            ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), -        ) - -        for timestamp, msg in timestamps: -            with self.subTest(msg=msg): -                result = TokenRemover.is_valid_timestamp(timestamp) -                self.assertFalse(result) - -    def test_is_valid_hmac_valid(self): -        """Should consider an HMAC valid if it has at least 3 unique characters.""" -        valid_hmacs = ( -            "VXmErH7j511turNpfURmb0rVNm8", -            "Ysnu2wacjaKs7qnoo46S8Dm2us8", -            "sJf6omBPORBPju3WJEIAcwW9Zds", -            "s45jqDV_Iisn-symw0yDRrk_jf4", -        ) - -        for hmac in valid_hmacs: -            with self.subTest(msg=hmac): -                result = TokenRemover.is_maybe_valid_hmac(hmac) -                self.assertTrue(result) - -    def test_is_invalid_hmac_invalid(self): -        """Should consider an HMAC invalid if has fewer than 3 unique characters.""" -        invalid_hmacs = ( -            ("xxxxxxxxxxxxxxxxxx", "Single character"), -            ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), -            ("ASFasfASFasfASFASsf", "Three characters alternating-case"), -            ("asdasdasdasdasdasdasd", "Three characters one case"), -        ) - -        for hmac, msg in invalid_hmacs: -            with self.subTest(msg=msg): -                result = TokenRemover.is_maybe_valid_hmac(hmac) -                self.assertFalse(result) - -    def test_mod_log_property(self): -        """The `mod_log` property should ask the bot to return the `ModLog` cog.""" -        self.bot.get_cog.return_value = 'lemon' -        self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) -        self.bot.get_cog.assert_called_once_with('ModLog') - -    async def test_on_message_edit_uses_on_message(self): -        """The edit listener should delegate handling of the message to the normal listener.""" -        self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) - -        await self.cog.on_message_edit(MockMessage(), self.msg) -        self.cog.on_message.assert_awaited_once_with(self.msg) - -    @autospec(TokenRemover, "find_token_in_message", "take_action") -    async def test_on_message_takes_action(self, find_token_in_message, take_action): -        """Should take action if a valid token is found when a message is sent.""" -        cog = TokenRemover(self.bot) -        found_token = "foobar" -        find_token_in_message.return_value = found_token - -        await cog.on_message(self.msg) - -        find_token_in_message.assert_called_once_with(self.msg) -        take_action.assert_awaited_once_with(cog, self.msg, found_token) - -    @autospec(TokenRemover, "find_token_in_message", "take_action") -    async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): -        """Shouldn't take action if a valid token isn't found when a message is sent.""" -        cog = TokenRemover(self.bot) -        find_token_in_message.return_value = False - -        await cog.on_message(self.msg) - -        find_token_in_message.assert_called_once_with(self.msg) -        take_action.assert_not_awaited() - -    @autospec(TokenRemover, "find_token_in_message") -    async def test_on_message_ignores_dms_bots(self, find_token_in_message): -        """Shouldn't parse a message if it is a DM or authored by a bot.""" -        cog = TokenRemover(self.bot) -        dm_msg = MockMessage(guild=None) -        bot_msg = MockMessage(author=MagicMock(bot=True)) - -        for msg in (dm_msg, bot_msg): -            await cog.on_message(msg) -            find_token_in_message.assert_not_called() - -    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") -    def test_find_token_no_matches(self, token_re): -        """None should be returned if the regex matches no tokens in a message.""" -        token_re.finditer.return_value = () - -        return_value = TokenRemover.find_token_in_message(self.msg) - -        self.assertIsNone(return_value) -        token_re.finditer.assert_called_once_with(self.msg.content) - -    @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") -    @autospec("bot.exts.filters.token_remover", "Token") -    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") -    def test_find_token_valid_match( -        self, -        token_re, -        token_cls, -        extract_user_id, -        is_valid_timestamp, -        is_maybe_valid_hmac, -    ): -        """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" -        matches = [ -            mock.create_autospec(Match, spec_set=True, instance=True), -            mock.create_autospec(Match, spec_set=True, instance=True), -        ] -        tokens = [ -            mock.create_autospec(Token, spec_set=True, instance=True), -            mock.create_autospec(Token, spec_set=True, instance=True), -        ] - -        token_re.finditer.return_value = matches -        token_cls.side_effect = tokens -        extract_user_id.side_effect = (None, True)  # The 1st match will be invalid, 2nd one valid. -        is_valid_timestamp.return_value = True -        is_maybe_valid_hmac.return_value = True - -        return_value = TokenRemover.find_token_in_message(self.msg) - -        self.assertEqual(tokens[1], return_value) -        token_re.finditer.assert_called_once_with(self.msg.content) - -    @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") -    @autospec("bot.exts.filters.token_remover", "Token") -    @autospec("bot.exts.filters.token_remover", "TOKEN_RE") -    def test_find_token_invalid_matches( -        self, -        token_re, -        token_cls, -        extract_user_id, -        is_valid_timestamp, -        is_maybe_valid_hmac, -    ): -        """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" -        token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] -        token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) -        extract_user_id.return_value = None -        is_valid_timestamp.return_value = False -        is_maybe_valid_hmac.return_value = False - -        return_value = TokenRemover.find_token_in_message(self.msg) - -        self.assertIsNone(return_value) -        token_re.finditer.assert_called_once_with(self.msg.content) - -    def test_regex_invalid_tokens(self): -        """Messages without anything looking like a token are not matched.""" -        tokens = ( -            "", -            "lemon wins", -            "..", -            "x.y", -            "x.y.", -            ".y.z", -            ".y.", -            "..z", -            "x..z", -            " . . ", -            "\n.\n.\n", -            "hellö.world.bye", -            "base64.nötbåse64.morebase64", -            "19jd3J.dfkm3d.€víł§tüff", -        ) - -        for token in tokens: -            with self.subTest(token=token): -                results = token_remover.TOKEN_RE.findall(token) -                self.assertEqual(len(results), 0) - -    def test_regex_valid_tokens(self): -        """Messages that look like tokens should be matched.""" -        # Don't worry, these tokens have been invalidated. -        tokens = ( -            "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", -            "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", -            "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", -            "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", -        ) - -        for token in tokens: -            with self.subTest(token=token): -                results = token_remover.TOKEN_RE.fullmatch(token) -                self.assertIsNotNone(results, f"{token} was not matched by the regex") - -    def test_regex_matches_multiple_valid(self): -        """Should support multiple matches in the middle of a string.""" -        token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" -        token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" -        message = f"garbage {token_1} hello {token_2} world" - -        results = token_remover.TOKEN_RE.finditer(message) -        results = [match[0] for match in results] -        self.assertCountEqual((token_1, token_2), results) - -    @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") -    def test_format_log_message(self, log_message): -        """Should correctly format the log message with info from the message and token.""" -        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        log_message.format.return_value = "Howdy" - -        return_value = TokenRemover.format_log_message(self.msg, token) - -        self.assertEqual(return_value, log_message.format.return_value) -        log_message.format.assert_called_once_with( -            author=format_user(self.msg.author), -            channel=self.msg.channel.mention, -            user_id=token.user_id, -            timestamp=token.timestamp, -            hmac="xxxxxxxxxxxxxxxxxxxxxxxxjf4", -        ) - -    @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE") -    async def test_format_userid_log_message_unknown(self, unknown_user_log_message,): -        """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" -        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        unknown_user_log_message.format.return_value = " Partner" -        msg = MockMessage(id=555, content="hello world") -        msg.guild.get_member.return_value = None -        msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found") - -        return_value = await TokenRemover.format_userid_log_message(msg, token) - -        self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) -        unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332) - -    @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") -    async def test_format_userid_log_message_bot(self, known_user_log_message): -        """Should correctly format the user ID portion when the ID belongs to a known bot.""" -        token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        known_user_log_message.format.return_value = " Partner" -        msg = MockMessage(id=555, content="hello world") -        msg.guild.get_member.return_value.__str__.return_value = "Sam" -        msg.guild.get_member.return_value.bot = True - -        return_value = await TokenRemover.format_userid_log_message(msg, token) - -        self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) - -        known_user_log_message.format.assert_called_once_with( -            user_id=472265943062413332, -            user_name="Sam", -            kind="BOT", -        ) - -    @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") -    async def test_format_log_message_user_token_user(self, user_token_message): -        """Should correctly format the user ID portion when the ID belongs to a known user.""" -        token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") -        user_token_message.format.return_value = "Partner" - -        return_value = await TokenRemover.format_userid_log_message(self.msg, token) - -        self.assertEqual(return_value, (user_token_message.format.return_value, True)) -        user_token_message.format.assert_called_once_with( -            user_id=467223230650777641, -            user_name="Woody", -            kind="USER", -        ) - -    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) -    @autospec("bot.exts.filters.token_remover", "log") -    @autospec(TokenRemover, "format_log_message", "format_userid_log_message") -    async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property): -        """Should delete the message and send a mod log.""" -        cog = TokenRemover(self.bot) -        mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) -        token = mock.create_autospec(Token, spec_set=True, instance=True) -        token.user_id = "no-id" -        log_msg = "testing123" -        userid_log_message = "userid-log-message" - -        mod_log_property.return_value = mod_log -        format_log_message.return_value = log_msg -        format_userid_log_message.return_value = (userid_log_message, True) - -        await cog.take_action(self.msg, token) - -        self.msg.delete.assert_called_once_with() -        self.msg.channel.send.assert_called_once_with( -            token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) -        ) - -        format_log_message.assert_called_once_with(self.msg, token) -        format_userid_log_message.assert_called_once_with(self.msg, token) -        logger.debug.assert_called_with(log_msg) -        self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - -        mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) -        mod_log.send_log_message.assert_called_once_with( -            icon_url=constants.Icons.token_removed, -            colour=Colour(constants.Colours.soft_red), -            title="Token removed!", -            text=log_msg + "\n" + userid_log_message, -            thumbnail=self.msg.author.display_avatar.url, -            channel_id=constants.Channels.mod_alerts, -            ping_everyone=True, -        ) - -    @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) -    async def test_take_action_delete_failure(self, mod_log_property): -        """Shouldn't send any messages if the token message can't be deleted.""" -        cog = TokenRemover(self.bot) -        mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) -        self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) - -        token = mock.create_autospec(Token, spec_set=True, instance=True) -        await cog.take_action(self.msg, token) - -        self.msg.delete.assert_called_once_with() -        self.msg.channel.send.assert_not_awaited() - - -class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): -    """Tests for the token_remover extension.""" - -    @autospec("bot.exts.filters.token_remover", "TokenRemover") -    async def test_extension_setup(self, cog): -        """The TokenRemover cog should be added.""" -        bot = MockBot() -        await token_remover.setup(bot) - -        cog.assert_called_once_with(bot) -        bot.add_cog.assert_awaited_once() -        self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9dcf7fd8c..79ac8ea2c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.upload_output = AsyncMock()  # Should not be called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code('MyAwesomeCode') @@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.upload_output = AsyncMock()  # This function isn't called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.cog.upload_output = AsyncMock()  # This function isn't called          mocked_filter_cog = MagicMock() -        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) +        mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"]))          self.bot.get_cog.return_value = mocked_filter_cog          job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py deleted file mode 100644 index 0d570f5a3..000000000 --- a/tests/bot/rules/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest -from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple - -from tests.helpers import MockMessage - - -class DisallowedCase(NamedTuple): -    """Encapsulation for test cases expected to fail.""" -    recent_messages: List[MockMessage] -    culprits: Iterable[str] -    n_violations: int - - -class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): -    """ -    Abstract class for antispam rule test cases. - -    Tests for specific rules should inherit from `RuleTest` and implement -    `relevant_messages` and `get_report`. Each instance should also set the -    `apply` and `config` attributes as necessary. - -    The execution of test cases can then be delegated to the `run_allowed` -    and `run_disallowed` methods. -    """ - -    apply: Callable  # The tested rule's apply function -    config: Dict[str, int] - -    async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: -        """Run all `cases` against `self.apply` expecting them to pass.""" -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config, -            ): -                self.assertIsNone( -                    await self.apply(last_message, recent_messages, self.config) -                ) - -    async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: -        """Run all `cases` against `self.apply` expecting them to fail.""" -        for case in cases: -            recent_messages, culprits, n_violations = case -            last_message = recent_messages[0] -            relevant_messages = self.relevant_messages(case) -            desired_output = ( -                self.get_report(case), -                culprits, -                relevant_messages, -            ) - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                n_violations=n_violations, -                config=self.config, -            ): -                self.assertTupleEqual( -                    await self.apply(last_message, recent_messages, self.config), -                    desired_output, -                ) - -    @abstractmethod -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        """Give expected relevant messages for `case`.""" -        raise NotImplementedError  # pragma: no cover - -    @abstractmethod -    def get_report(self, case: DisallowedCase) -> str: -        """Give expected error report for `case`.""" -        raise NotImplementedError  # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py deleted file mode 100644 index d7e779221..000000000 --- a/tests/bot/rules/test_attachments.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Iterable - -from bot.rules import attachments -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_attachments: int) -> MockMessage: -    """Builds a message with `total_attachments` attachments.""" -    return MockMessage(author=author, attachments=list(range(total_attachments))) - - -class AttachmentRuleTests(RuleTest): -    """Tests applying the `attachments` antispam rule.""" - -    def setUp(self): -        self.apply = attachments.apply -        self.config = {"max": 5, "interval": 10} - -    async def test_allows_messages_without_too_many_attachments(self): -        """Messages without too many attachments are allowed as-is.""" -        cases = ( -            [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], -            [make_msg("bob", 2), make_msg("bob", 2)], -            [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_with_too_many_attachments(self): -        """Messages with too many attachments trigger the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], -                ("bob",), -                10, -            ), -            DisallowedCase( -                [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], -                ("bob",), -                6, -            ), -            DisallowedCase( -                [make_msg("alice", 6)], -                ("alice",), -                6, -            ), -            DisallowedCase( -                [make_msg("alice", 1) for _ in range(6)], -                ("alice",), -                6, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if ( -                msg.author == last_message.author -                and len(msg.attachments) > 0 -            ) -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py deleted file mode 100644 index 03682966b..000000000 --- a/tests/bot/rules/test_burst.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Iterable - -from bot.rules import burst -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: -    """ -    Init a MockMessage instance with author set to `author`. - -    This serves as a shorthand / alias to keep the test cases visually clean. -    """ -    return MockMessage(author=author) - - -class BurstRuleTests(RuleTest): -    """Tests the `burst` antispam rule.""" - -    def setUp(self): -        self.apply = burst.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases which do not violate the rule.""" -        cases = ( -            [make_msg("bob"), make_msg("bob")], -            [make_msg("bob"), make_msg("alice"), make_msg("bob")], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases where the amount of messages exceeds the limit, triggering the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("bob")], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], -                ("bob",), -                3, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py deleted file mode 100644 index 3275143d5..000000000 --- a/tests/bot/rules/test_burst_shared.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Iterable - -from bot.rules import burst_shared -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: -    """ -    Init a MockMessage instance with the passed arg. - -    This serves as a shorthand / alias to keep the test cases visually clean. -    """ -    return MockMessage(author=author) - - -class BurstSharedRuleTests(RuleTest): -    """Tests the `burst_shared` antispam rule.""" - -    def setUp(self): -        self.apply = burst_shared.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """ -        Cases that do not violate the rule. - -        There really isn't more to test here than a single case. -        """ -        cases = ( -            [make_msg("spongebob"), make_msg("patrick")], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases where the amount of messages exceeds the limit, triggering the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("bob")], -                {"bob"}, -                3, -            ), -            DisallowedCase( -                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], -                {"bob", "alice"}, -                4, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        return case.recent_messages - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py deleted file mode 100644 index f1e3c76a7..000000000 --- a/tests/bot/rules/test_chars.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import chars -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_chars: int) -> MockMessage: -    """Build a message with arbitrary content of `n_chars` length.""" -    return MockMessage(author=author, content="A" * n_chars) - - -class CharsRuleTests(RuleTest): -    """Tests the `chars` antispam rule.""" - -    def setUp(self): -        self.apply = chars.apply -        self.config = { -            "max": 20,  # Max allowed sum of chars per user -            "interval": 10, -        } - -    async def test_allows_messages_within_limit(self): -        """Cases with a total amount of chars within limit.""" -        cases = ( -            [make_msg("bob", 0)], -            [make_msg("bob", 20)], -            [make_msg("bob", 15), make_msg("alice", 15)], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases where the total amount of chars exceeds the limit, triggering the rule.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 21)], -                ("bob",), -                21, -            ), -            DisallowedCase( -                [make_msg("bob", 15), make_msg("bob", 15)], -                ("bob",), -                30, -            ), -            DisallowedCase( -                [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], -                ("alice",), -                30, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py deleted file mode 100644 index 66c2d9f92..000000000 --- a/tests/bot/rules/test_discord_emojis.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Iterable - -from bot.rules import discord_emojis -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - -discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> -unicode_emoji = "🧪" - - -def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage: -    """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" -    return MockMessage(author=author, content=emoji * n_emojis) - - -class DiscordEmojisRuleTests(RuleTest): -    """Tests for the `discord_emojis` antispam rule.""" - -    def setUp(self): -        self.apply = discord_emojis.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases with a total amount of discord and unicode emojis within limit.""" -        cases = ( -            [make_msg("bob", 2)], -            [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], -            [make_msg("bob", 2, unicode_emoji)], -            [ -                make_msg("alice", 1, unicode_emoji), -                make_msg("bob", 2, unicode_emoji), -                make_msg("alice", 1, unicode_emoji) -            ], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases with more than the allowed amount of discord and unicode emojis.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 3)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], -                ("alice",), -                4, -            ), -            DisallowedCase( -                [make_msg("bob", 3, unicode_emoji)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [ -                    make_msg("alice", 2, unicode_emoji), -                    make_msg("bob", 2, unicode_emoji), -                    make_msg("alice", 2, unicode_emoji) -                ], -                ("alice",), -                4 -            ) -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py deleted file mode 100644 index 9bd886a77..000000000 --- a/tests/bot/rules/test_duplicates.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import duplicates -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, content: str) -> MockMessage: -    """Give a MockMessage instance with `author` and `content` attrs.""" -    return MockMessage(author=author, content=content) - - -class DuplicatesRuleTests(RuleTest): -    """Tests the `duplicates` antispam rule.""" - -    def setUp(self): -        self.apply = duplicates.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases which do not violate the rule.""" -        cases = ( -            [make_msg("alice", "A"), make_msg("alice", "A")], -            [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")],  # Non-duplicate -            [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")],  # Different author -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases with too many duplicate messages from the same author.""" -        cases = ( -            DisallowedCase( -                [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], -                ("alice",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], -                ("bob",), -                3,  # 4 duplicate messages, but only 3 from bob -            ), -            DisallowedCase( -                [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], -                ("bob",), -                3,  # 4 message from bob, but only 3 duplicates -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if ( -                msg.author == last_message.author -                and msg.content == last_message.content -            ) -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py deleted file mode 100644 index b091bd9d7..000000000 --- a/tests/bot/rules/test_links.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Iterable - -from bot.rules import links -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_links: int) -> MockMessage: -    """Makes a message with `total_links` links.""" -    content = " ".join(["https://pydis.com"] * total_links) -    return MockMessage(author=author, content=content) - - -class LinksTests(RuleTest): -    """Tests applying the `links` rule.""" - -    def setUp(self): -        self.apply = links.apply -        self.config = { -            "max": 2, -            "interval": 10 -        } - -    async def test_links_within_limit(self): -        """Messages with an allowed amount of links.""" -        cases = ( -            [make_msg("bob", 0)], -            [make_msg("bob", 2)], -            [make_msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 -            [make_msg("bob", 1), make_msg("bob", 1)], -            [make_msg("bob", 2), make_msg("alice", 2)]  # Only messages from latest author count -        ) - -        await self.run_allowed(cases) - -    async def test_links_exceeding_limit(self): -        """Messages with a a higher than allowed amount of links.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 1), make_msg("bob", 2)], -                ("bob",), -                3 -            ), -            DisallowedCase( -                [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], -                ("alice",), -                3 -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], -                ("alice",), -                3 -            ) -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py deleted file mode 100644 index e1f904917..000000000 --- a/tests/bot/rules/test_mentions.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Iterable, Optional - -import discord - -from bot.rules import mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage, MockMessageReference - - -def make_msg( -    author: str, -    total_user_mentions: int, -    total_bot_mentions: int = 0, -    *, -    reference: Optional[MockMessageReference] = None -) -> MockMessage: -    """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" -    user_mentions = [MockMember() for _ in range(total_user_mentions)] -    bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - -    mentions = user_mentions + bot_mentions -    if reference is not None: -        # For the sake of these tests we assume that all references are mentions. -        mentions.append(reference.resolved.author) -        msg_type = discord.MessageType.reply -    else: -        msg_type = discord.MessageType.default - -    return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) - - -class TestMentions(RuleTest): -    """Tests applying the `mentions` antispam rule.""" - -    def setUp(self): -        self.apply = mentions.apply -        self.config = { -            "max": 2, -            "interval": 10, -        } - -    async def test_mentions_within_limit(self): -        """Messages with an allowed amount of mentions.""" -        cases = ( -            [make_msg("bob", 0)], -            [make_msg("bob", 2)], -            [make_msg("bob", 1), make_msg("bob", 1)], -            [make_msg("bob", 1), make_msg("alice", 2)], -        ) - -        await self.run_allowed(cases) - -    async def test_mentions_exceeding_limit(self): -        """Messages with a higher than allowed amount of mentions.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 3)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], -                ("alice",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], -                ("bob",), -                4, -            ), -            DisallowedCase( -                [make_msg("bob", 3, 1)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", 3, reference=MockMessageReference())], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], -                ("bob",), -                3 -            ) -        ) - -        await self.run_disallowed(cases) - -    async def test_ignore_bot_mentions(self): -        """Messages with an allowed amount of mentions, also containing bot mentions.""" -        cases = ( -            [make_msg("bob", 0, 3)], -            [make_msg("bob", 2, 1)], -            [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], -            [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] -        ) - -        await self.run_allowed(cases) - -    async def test_ignore_reply_mentions(self): -        """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" -        cases = ( -            [ -                make_msg("bob", 2, reference=MockMessageReference()) -            ], -            [ -                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) -            ], -            [ -                make_msg("bob", 2, reference=MockMessageReference()), -                make_msg("bob", 0, reference=MockMessageReference()) -            ], -            [ -                make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), -                make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) -            ] -        ) - -        await self.run_allowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py deleted file mode 100644 index e35377773..000000000 --- a/tests/bot/rules/test_newlines.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Iterable, List - -from bot.rules import newlines -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, newline_groups: List[int]) -> MockMessage: -    """Init a MockMessage instance with `author` and content configured by `newline_groups". - -    Configure content by passing a list of ints, where each int `n` will generate -    a separate group of `n` newlines. - -    Example: -        newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" -    """ -    content = " ".join("\n" * n for n in newline_groups) -    return MockMessage(author=author, content=content) - - -class TotalNewlinesRuleTests(RuleTest): -    """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" - -    def setUp(self): -        self.apply = newlines.apply -        self.config = { -            "max": 5,  # Max sum of newlines in relevant messages -            "max_consecutive": 3,  # Max newlines in one group, in one message -            "interval": 10, -        } - -    async def test_allows_messages_within_limit(self): -        """Cases which do not violate the rule.""" -        cases = ( -            [make_msg("alice", [])],  # Single message with no newlines -            [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])],  # 5 newlines in 2 messages -            [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])],  # 5 newlines from each author -            [make_msg("bob", [1]), make_msg("alice", [5])],  # Alice breaks the rule, but only bob is relevant -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_total(self): -        """Cases which violate the rule by having too many newlines in total.""" -        cases = ( -            DisallowedCase(  # Alice sends a total of 6 newlines (disallowed) -                [make_msg("alice", [2, 2]), make_msg("alice", [2])], -                ("alice",), -                6, -            ), -            DisallowedCase(  # Here we test that only alice's newlines count in the sum -                [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], -                ("alice",), -                7, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_author = case.recent_messages[0].author -        return tuple(msg for msg in case.recent_messages if msg.author == last_author) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} newlines in {self.config['interval']}s" - - -class GroupNewlinesRuleTests(RuleTest): -    """ -    Tests the `newlines` antispam rule against max consecutive newline violations. - -    As these violations yield a different error report, they require a different -    `get_report` implementation. -    """ - -    def setUp(self): -        self.apply = newlines.apply -        self.config = {"max": 5, "max_consecutive": 3, "interval": 10} - -    async def test_disallows_messages_consecutive(self): -        """Cases which violate the rule due to having too many consecutive newlines.""" -        cases = ( -            DisallowedCase(  # Bob sends a group of newlines too large -                [make_msg("bob", [4])], -                ("bob",), -                4, -            ), -            DisallowedCase(  # Alice sends 5 in total (allowed), but 4 in one group (disallowed) -                [make_msg("alice", [1]), make_msg("alice", [4])], -                ("alice",), -                4, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_author = case.recent_messages[0].author -        return tuple(msg for msg in case.recent_messages if msg.author == last_author) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py deleted file mode 100644 index 26c05d527..000000000 --- a/tests/bot/rules/test_role_mentions.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Iterable - -from bot.rules import role_mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_mentions: int) -> MockMessage: -    """Build a MockMessage instance with `n_mentions` role mentions.""" -    return MockMessage(author=author, role_mentions=[None] * n_mentions) - - -class RoleMentionsRuleTests(RuleTest): -    """Tests for the `role_mentions` antispam rule.""" - -    def setUp(self): -        self.apply = role_mentions.apply -        self.config = {"max": 2, "interval": 10} - -    async def test_allows_messages_within_limit(self): -        """Cases with a total amount of role mentions within limit.""" -        cases = ( -            [make_msg("bob", 2)], -            [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], -        ) - -        await self.run_allowed(cases) - -    async def test_disallows_messages_beyond_limit(self): -        """Cases with more than the allowed amount of role mentions.""" -        cases = ( -            DisallowedCase( -                [make_msg("bob", 3)], -                ("bob",), -                3, -            ), -            DisallowedCase( -                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], -                ("alice",), -                4, -            ), -        ) - -        await self.run_disallowed(cases) - -    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: -        last_message = case.recent_messages[0] -        return tuple( -            msg -            for msg in case.recent_messages -            if msg.author == last_message.author -        ) - -    def get_report(self, case: DisallowedCase) -> str: -        return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/helpers.py b/tests/helpers.py index 1a71f210a..020f1aee5 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data  class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      """ -    A MagicMock subclass to mock TextChannel objects. +    A MagicMock subclass to mock DMChannel objects. -    Instances of this class will follow the specifications of `discord.TextChannel` instances. For +    Instances of this class will follow the specifications of `discord.DMChannel` instances. For      more information, see the `MockGuild` docstring.      """      spec_set = dm_channel_instance      def __init__(self, **kwargs) -> None: -        default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} +        default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None}          super().__init__(**collections.ChainMap(kwargs, default_kwargs)) @@ -423,7 +423,7 @@ category_channel_instance = discord.CategoryChannel(  class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id)} -        super().__init__(**collections.ChainMap(default_kwargs, kwargs)) +        super().__init__(**collections.ChainMap(kwargs, default_kwargs))  # Create a Message instance to get a realistic MagicMock of `discord.Message` | 
