From d2226cc067aebd61113c584ccf833d55cf227a2d Mon Sep 17 00:00:00 2001 From: mbaruh Date: Wed, 1 Dec 2021 22:15:31 +0200 Subject: Tear down the old filtering system Tests and dependent functionality in other extensions will be re-added later on. --- tests/bot/exts/filters/__init__.py | 0 tests/bot/exts/filters/test_antimalware.py | 202 ------------- tests/bot/exts/filters/test_antispam.py | 35 --- tests/bot/exts/filters/test_filtering.py | 40 --- tests/bot/exts/filters/test_security.py | 53 ---- tests/bot/exts/filters/test_token_remover.py | 409 --------------------------- tests/bot/rules/__init__.py | 76 ----- tests/bot/rules/test_attachments.py | 69 ----- tests/bot/rules/test_burst.py | 54 ---- tests/bot/rules/test_burst_shared.py | 57 ---- tests/bot/rules/test_chars.py | 64 ----- tests/bot/rules/test_discord_emojis.py | 73 ----- tests/bot/rules/test_duplicates.py | 64 ----- tests/bot/rules/test_links.py | 67 ----- tests/bot/rules/test_mentions.py | 83 ------ tests/bot/rules/test_newlines.py | 102 ------- tests/bot/rules/test_role_mentions.py | 55 ---- 17 files changed, 1503 deletions(-) delete mode 100644 tests/bot/exts/filters/__init__.py delete mode 100644 tests/bot/exts/filters/test_antimalware.py delete mode 100644 tests/bot/exts/filters/test_antispam.py delete mode 100644 tests/bot/exts/filters/test_filtering.py delete mode 100644 tests/bot/exts/filters/test_security.py delete mode 100644 tests/bot/exts/filters/test_token_remover.py delete mode 100644 tests/bot/rules/__init__.py delete mode 100644 tests/bot/rules/test_attachments.py delete mode 100644 tests/bot/rules/test_burst.py delete mode 100644 tests/bot/rules/test_burst_shared.py delete mode 100644 tests/bot/rules/test_chars.py delete mode 100644 tests/bot/rules/test_discord_emojis.py delete mode 100644 tests/bot/rules/test_duplicates.py delete mode 100644 tests/bot/rules/test_links.py delete mode 100644 tests/bot/rules/test_mentions.py delete mode 100644 tests/bot/rules/test_newlines.py delete mode 100644 tests/bot/rules/test_role_mentions.py (limited to 'tests') diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filters/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py deleted file mode 100644 index 7282334e2..000000000 --- a/tests/bot/exts/filters/test_antimalware.py +++ /dev/null @@ -1,202 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, Mock - -from discord import NotFound - -from bot.constants import Channels, STAFF_ROLES -from bot.exts.filters import antimalware -from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole - - -class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): - """Test the AntiMalware cog.""" - - def setUp(self): - """Sets up fresh objects for each test.""" - self.bot = MockBot() - self.bot.filter_list_cache = { - "FILE_FORMAT.True": { - ".first": {}, - ".second": {}, - ".third": {}, - } - } - self.cog = antimalware.AntiMalware(self.bot) - self.message = MockMessage() - self.message.webhook_id = None - self.message.author.bot = None - self.whitelist = [".first", ".second", ".third"] - - async def test_message_with_allowed_attachment(self): - """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_message_without_attachment(self): - """Messages without attachments should result in no action.""" - await self.cog.on_message(self.message) - self.message.delete.assert_not_called() - - async def test_direct_message_with_attachment(self): - """Direct messages should have no action taken.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.guild = None - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_webhook_message_with_illegal_extension(self): - """A webhook message containing an illegal extension should be ignored.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.webhook_id = 697140105563078727 - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_bot_message_with_illegal_extension(self): - """A bot message containing an illegal extension should be ignored.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.author.bot = 409107086526644234 - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_message_with_illegal_extension_gets_deleted(self): - """A message containing an illegal extension should send an embed.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_called_once() - - async def test_message_send_by_staff(self): - """A message send by a member of staff should be ignored.""" - staff_role = MockRole(id=STAFF_ROLES[0]) - self.message.author.roles.append(staff_role) - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - await self.cog.on_message(self.message) - - self.message.delete.assert_not_called() - - async def test_python_file_redirect_embed_description(self): - """A message containing a .py file should result in an embed redirecting the user to our paste site""" - attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - - self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) - - async def test_txt_file_redirect_embed_description(self): - """A message containing a .txt/.json/.csv file should result in the correct embed.""" - test_values = ( - ("text", ".txt"), - ("json", ".json"), - ("csv", ".csv"), - ) - - for file_name, disallowed_extension in test_values: - with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): - - attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.TXT_EMBED_DESCRIPTION = Mock() - antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - cmd_channel = self.bot.get_channel(Channels.bot_commands) - - self.assertEqual( - embed.description, - antimalware.TXT_EMBED_DESCRIPTION.format.return_value - ) - antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with( - blocked_extension=disallowed_extension, - cmd_channel_mention=cmd_channel.mention - ) - - async def test_other_disallowed_extension_embed_description(self): - """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.channel.send = AsyncMock() - antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" - - await self.cog.on_message(self.message) - self.message.channel.send.assert_called_once() - args, kwargs = self.message.channel.send.call_args - embed = kwargs.pop("embed") - meta_channel = self.bot.get_channel(Channels.meta) - - self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) - antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( - joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", - meta_channel_mention=meta_channel.mention - ) - - async def test_removing_deleted_message_logs(self): - """Removing an already deleted message logs the correct message""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - self.message.delete.assert_called_once() - - async def test_message_with_illegal_attachment_logs(self): - """Deleting a message with an illegal attachment should result in a log.""" - attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] - - with self.assertLogs(logger=antimalware.log, level="INFO"): - await self.cog.on_message(self.message) - - async def test_get_disallowed_extensions(self): - """The return value should include all non-whitelisted extensions.""" - test_values = ( - ([], []), - (self.whitelist, []), - ([".first"], []), - ([".first", ".disallowed"], [".disallowed"]), - ([".disallowed"], [".disallowed"]), - ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), - ) - - for extensions, expected_disallowed_extensions in test_values: - with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog._get_disallowed_extensions(self.message) - self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) - - -class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase): - """Tests setup of the `AntiMalware` cog.""" - - async def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - await antimalware.setup(bot) - bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py deleted file mode 100644 index 6a0e4fded..000000000 --- a/tests/bot/exts/filters/test_antispam.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from bot.exts.filters import antispam - - -class AntispamConfigurationValidationTests(unittest.TestCase): - """Tests validation of the antispam cog configuration.""" - - def test_default_antispam_config_is_valid(self): - """The default antispam configuration is valid.""" - validation_errors = antispam.validate_config() - self.assertEqual(validation_errors, {}) - - def test_unknown_rule_returns_error(self): - """Configuring an unknown rule returns an error.""" - self.assertEqual( - antispam.validate_config({'invalid-rule': {}}), - {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} - ) - - def test_missing_keys_returns_error(self): - """Not configuring required keys returns an error.""" - keys = (('interval', 'max'), ('max', 'interval')) - for configured_key, unconfigured_key in keys: - with self.subTest( - configured_key=configured_key, - unconfigured_key=unconfigured_key - ): - config = {'burst': {configured_key: 10}} - error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" - - self.assertEqual( - antispam.validate_config(config), - {'burst': error} - ) diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py deleted file mode 100644 index bd26532f1..000000000 --- a/tests/bot/exts/filters/test_filtering.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -from unittest.mock import patch - -from bot.exts.filters import filtering -from tests.helpers import MockBot, autospec - - -class FilteringCogTests(unittest.IsolatedAsyncioTestCase): - """Tests the `Filtering` cog.""" - - def setUp(self): - """Instantiate the bot and cog.""" - self.bot = MockBot() - with patch("botcore.utils.scheduling.create_task", new=lambda task, **_: task.close()): - self.cog = filtering.Filtering(self.bot) - - @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"]) - async def test_token_filter(self): - """Ensure that a filter token is correctly detected in a message.""" - messages = { - "": False, - "no matches": False, - "TOKEN": True, - - # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 - "https://google.com TOKEN": True, - "https://google.com something else": False, - } - - for message, match in messages.items(): - with self.subTest(input=message, match=match): - result, _ = await self.cog._has_watch_regex_match(message) - - self.assertEqual( - match, - bool(result), - msg=f"Hit was {'expected' if match else 'not expected'} for this input." - ) - if result: - self.assertEqual("TOKEN", result.group()) diff --git a/tests/bot/exts/filters/test_security.py b/tests/bot/exts/filters/test_security.py deleted file mode 100644 index 007b7b1eb..000000000 --- a/tests/bot/exts/filters/test_security.py +++ /dev/null @@ -1,53 +0,0 @@ -import unittest - -from discord.ext.commands import NoPrivateMessage - -from bot.exts.filters import security -from tests.helpers import MockBot, MockContext - - -class SecurityCogTests(unittest.TestCase): - """Tests the `Security` cog.""" - - def setUp(self): - """Attach an instance of the cog to the class for tests.""" - self.bot = MockBot() - self.cog = security.Security(self.bot) - self.ctx = MockContext() - - def test_check_additions(self): - """The cog should add its checks after initialization.""" - self.bot.check.assert_any_call(self.cog.check_on_guild) - self.bot.check.assert_any_call(self.cog.check_not_bot) - - def test_check_not_bot_returns_false_for_humans(self): - """The bot check should return `True` when invoked with human authors.""" - self.ctx.author.bot = False - self.assertTrue(self.cog.check_not_bot(self.ctx)) - - def test_check_not_bot_returns_true_for_robots(self): - """The bot check should return `False` when invoked with robotic authors.""" - self.ctx.author.bot = True - self.assertFalse(self.cog.check_not_bot(self.ctx)) - - def test_check_on_guild_raises_when_outside_of_guild(self): - """When invoked outside of a guild, `check_on_guild` should cause an error.""" - self.ctx.guild = None - - with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): - self.cog.check_on_guild(self.ctx) - - def test_check_on_guild_returns_true_inside_of_guild(self): - """When invoked inside of a guild, `check_on_guild` should return `True`.""" - self.ctx.guild = "lemon's lemonade stand" - self.assertTrue(self.cog.check_on_guild(self.ctx)) - - -class SecurityCogLoadTests(unittest.IsolatedAsyncioTestCase): - """Tests loading the `Security` cog.""" - - async def test_security_cog_load(self): - """Setup of the extension should call add_cog.""" - bot = MockBot() - await security.setup(bot) - bot.add_cog.assert_awaited_once() diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py deleted file mode 100644 index c1f3762ac..000000000 --- a/tests/bot/exts/filters/test_token_remover.py +++ /dev/null @@ -1,409 +0,0 @@ -import unittest -from re import Match -from unittest import mock -from unittest.mock import MagicMock - -from discord import Colour, NotFound - -from bot import constants -from bot.exts.filters import token_remover -from bot.exts.filters.token_remover import Token, TokenRemover -from bot.exts.moderation.modlog import ModLog -from bot.utils.messages import format_user -from tests.helpers import MockBot, MockMessage, autospec - - -class TokenRemoverTests(unittest.IsolatedAsyncioTestCase): - """Tests the `TokenRemover` cog.""" - - def setUp(self): - """Adds the cog, a bot, and a message to the instance for usage in tests.""" - self.bot = MockBot() - self.cog = TokenRemover(bot=self.bot) - - self.msg = MockMessage(id=555, content="hello world") - self.msg.channel.mention = "#lemonade-stand" - self.msg.guild.get_member.return_value.bot = False - self.msg.guild.get_member.return_value.__str__.return_value = "Woody" - self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) - self.msg.author.display_avatar.url = "picture-lemon.png" - - def test_extract_user_id_valid(self): - """Should consider user IDs valid if they decode into an integer ID.""" - id_pairs = ( - ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), - ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), - ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), - ) - - for token_id, user_id in id_pairs: - with self.subTest(token_id=token_id): - result = TokenRemover.extract_user_id(token_id) - self.assertEqual(result, user_id) - - def test_extract_user_id_invalid(self): - """Should consider non-digit and non-ASCII IDs invalid.""" - ids = ( - ("SGVsbG8gd29ybGQ", "non-digit ASCII"), - ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), - ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), - ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), - ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for user_id, msg in ids: - with self.subTest(msg=msg): - result = TokenRemover.extract_user_id(user_id) - self.assertIsNone(result) - - def test_is_valid_timestamp_valid(self): - """Should consider timestamps valid if they're greater than the Discord epoch.""" - timestamps = ( - "XsyRkw", - "Xrim9Q", - "XsyR-w", - "XsySD_", - "Dn9r_A", - ) - - for timestamp in timestamps: - with self.subTest(timestamp=timestamp): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertTrue(result) - - def test_is_valid_timestamp_invalid(self): - """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" - timestamps = ( - ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), - ("ew", "123"), - ("AoIKgA", "42076800"), - ("{hello}[world]&(bye!)", "ASCII invalid Base64"), - ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), - ) - - for timestamp, msg in timestamps: - with self.subTest(msg=msg): - result = TokenRemover.is_valid_timestamp(timestamp) - self.assertFalse(result) - - def test_is_valid_hmac_valid(self): - """Should consider an HMAC valid if it has at least 3 unique characters.""" - valid_hmacs = ( - "VXmErH7j511turNpfURmb0rVNm8", - "Ysnu2wacjaKs7qnoo46S8Dm2us8", - "sJf6omBPORBPju3WJEIAcwW9Zds", - "s45jqDV_Iisn-symw0yDRrk_jf4", - ) - - for hmac in valid_hmacs: - with self.subTest(msg=hmac): - result = TokenRemover.is_maybe_valid_hmac(hmac) - self.assertTrue(result) - - def test_is_invalid_hmac_invalid(self): - """Should consider an HMAC invalid if has fewer than 3 unique characters.""" - invalid_hmacs = ( - ("xxxxxxxxxxxxxxxxxx", "Single character"), - ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), - ("ASFasfASFasfASFASsf", "Three characters alternating-case"), - ("asdasdasdasdasdasdasd", "Three characters one case"), - ) - - for hmac, msg in invalid_hmacs: - with self.subTest(msg=msg): - result = TokenRemover.is_maybe_valid_hmac(hmac) - self.assertFalse(result) - - def test_mod_log_property(self): - """The `mod_log` property should ask the bot to return the `ModLog` cog.""" - self.bot.get_cog.return_value = 'lemon' - self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) - self.bot.get_cog.assert_called_once_with('ModLog') - - async def test_on_message_edit_uses_on_message(self): - """The edit listener should delegate handling of the message to the normal listener.""" - self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True) - - await self.cog.on_message_edit(MockMessage(), self.msg) - self.cog.on_message.assert_awaited_once_with(self.msg) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_takes_action(self, find_token_in_message, take_action): - """Should take action if a valid token is found when a message is sent.""" - cog = TokenRemover(self.bot) - found_token = "foobar" - find_token_in_message.return_value = found_token - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_awaited_once_with(cog, self.msg, found_token) - - @autospec(TokenRemover, "find_token_in_message", "take_action") - async def test_on_message_skips_missing_token(self, find_token_in_message, take_action): - """Shouldn't take action if a valid token isn't found when a message is sent.""" - cog = TokenRemover(self.bot) - find_token_in_message.return_value = False - - await cog.on_message(self.msg) - - find_token_in_message.assert_called_once_with(self.msg) - take_action.assert_not_awaited() - - @autospec(TokenRemover, "find_token_in_message") - async def test_on_message_ignores_dms_bots(self, find_token_in_message): - """Shouldn't parse a message if it is a DM or authored by a bot.""" - cog = TokenRemover(self.bot) - dm_msg = MockMessage(guild=None) - bot_msg = MockMessage(author=MagicMock(bot=True)) - - for msg in (dm_msg, bot_msg): - await cog.on_message(msg) - find_token_in_message.assert_not_called() - - @autospec("bot.exts.filters.token_remover", "TOKEN_RE") - def test_find_token_no_matches(self, token_re): - """None should be returned if the regex matches no tokens in a message.""" - token_re.finditer.return_value = () - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") - @autospec("bot.exts.filters.token_remover", "Token") - @autospec("bot.exts.filters.token_remover", "TOKEN_RE") - def test_find_token_valid_match( - self, - token_re, - token_cls, - extract_user_id, - is_valid_timestamp, - is_maybe_valid_hmac, - ): - """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" - matches = [ - mock.create_autospec(Match, spec_set=True, instance=True), - mock.create_autospec(Match, spec_set=True, instance=True), - ] - tokens = [ - mock.create_autospec(Token, spec_set=True, instance=True), - mock.create_autospec(Token, spec_set=True, instance=True), - ] - - token_re.finditer.return_value = matches - token_cls.side_effect = tokens - extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid. - is_valid_timestamp.return_value = True - is_maybe_valid_hmac.return_value = True - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertEqual(tokens[1], return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") - @autospec("bot.exts.filters.token_remover", "Token") - @autospec("bot.exts.filters.token_remover", "TOKEN_RE") - def test_find_token_invalid_matches( - self, - token_re, - token_cls, - extract_user_id, - is_valid_timestamp, - is_maybe_valid_hmac, - ): - """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" - token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] - token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) - extract_user_id.return_value = None - is_valid_timestamp.return_value = False - is_maybe_valid_hmac.return_value = False - - return_value = TokenRemover.find_token_in_message(self.msg) - - self.assertIsNone(return_value) - token_re.finditer.assert_called_once_with(self.msg.content) - - def test_regex_invalid_tokens(self): - """Messages without anything looking like a token are not matched.""" - tokens = ( - "", - "lemon wins", - "..", - "x.y", - "x.y.", - ".y.z", - ".y.", - "..z", - "x..z", - " . . ", - "\n.\n.\n", - "hellö.world.bye", - "base64.nötbåse64.morebase64", - "19jd3J.dfkm3d.€víł§tüff", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.findall(token) - self.assertEqual(len(results), 0) - - def test_regex_valid_tokens(self): - """Messages that look like tokens should be matched.""" - # Don't worry, these tokens have been invalidated. - tokens = ( - "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", - "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", - "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", - "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", - ) - - for token in tokens: - with self.subTest(token=token): - results = token_remover.TOKEN_RE.fullmatch(token) - self.assertIsNotNone(results, f"{token} was not matched by the regex") - - def test_regex_matches_multiple_valid(self): - """Should support multiple matches in the middle of a string.""" - token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" - token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" - message = f"garbage {token_1} hello {token_2} world" - - results = token_remover.TOKEN_RE.finditer(message) - results = [match[0] for match in results] - self.assertCountEqual((token_1, token_2), results) - - @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE") - def test_format_log_message(self, log_message): - """Should correctly format the log message with info from the message and token.""" - token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - log_message.format.return_value = "Howdy" - - return_value = TokenRemover.format_log_message(self.msg, token) - - self.assertEqual(return_value, log_message.format.return_value) - log_message.format.assert_called_once_with( - author=format_user(self.msg.author), - channel=self.msg.channel.mention, - user_id=token.user_id, - timestamp=token.timestamp, - hmac="xxxxxxxxxxxxxxxxxxxxxxxxjf4", - ) - - @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE") - async def test_format_userid_log_message_unknown(self, unknown_user_log_message,): - """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" - token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - unknown_user_log_message.format.return_value = " Partner" - msg = MockMessage(id=555, content="hello world") - msg.guild.get_member.return_value = None - msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found") - - return_value = await TokenRemover.format_userid_log_message(msg, token) - - self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) - unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332) - - @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") - async def test_format_userid_log_message_bot(self, known_user_log_message): - """Should correctly format the user ID portion when the ID belongs to a known bot.""" - token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - known_user_log_message.format.return_value = " Partner" - msg = MockMessage(id=555, content="hello world") - msg.guild.get_member.return_value.__str__.return_value = "Sam" - msg.guild.get_member.return_value.bot = True - - return_value = await TokenRemover.format_userid_log_message(msg, token) - - self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) - - known_user_log_message.format.assert_called_once_with( - user_id=472265943062413332, - user_name="Sam", - kind="BOT", - ) - - @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE") - async def test_format_log_message_user_token_user(self, user_token_message): - """Should correctly format the user ID portion when the ID belongs to a known user.""" - token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") - user_token_message.format.return_value = "Partner" - - return_value = await TokenRemover.format_userid_log_message(self.msg, token) - - self.assertEqual(return_value, (user_token_message.format.return_value, True)) - user_token_message.format.assert_called_once_with( - user_id=467223230650777641, - user_name="Woody", - kind="USER", - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - @autospec("bot.exts.filters.token_remover", "log") - @autospec(TokenRemover, "format_log_message", "format_userid_log_message") - async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property): - """Should delete the message and send a mod log.""" - cog = TokenRemover(self.bot) - mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True) - token = mock.create_autospec(Token, spec_set=True, instance=True) - token.user_id = "no-id" - log_msg = "testing123" - userid_log_message = "userid-log-message" - - mod_log_property.return_value = mod_log - format_log_message.return_value = log_msg - format_userid_log_message.return_value = (userid_log_message, True) - - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_called_once_with( - token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention) - ) - - format_log_message.assert_called_once_with(self.msg, token) - format_userid_log_message.assert_called_once_with(self.msg, token) - logger.debug.assert_called_with(log_msg) - self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens") - - mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id) - mod_log.send_log_message.assert_called_once_with( - icon_url=constants.Icons.token_removed, - colour=Colour(constants.Colours.soft_red), - title="Token removed!", - text=log_msg + "\n" + userid_log_message, - thumbnail=self.msg.author.display_avatar.url, - channel_id=constants.Channels.mod_alerts, - ping_everyone=True, - ) - - @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock) - async def test_take_action_delete_failure(self, mod_log_property): - """Shouldn't send any messages if the token message can't be deleted.""" - cog = TokenRemover(self.bot) - mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True) - self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock()) - - token = mock.create_autospec(Token, spec_set=True, instance=True) - await cog.take_action(self.msg, token) - - self.msg.delete.assert_called_once_with() - self.msg.channel.send.assert_not_awaited() - - -class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase): - """Tests for the token_remover extension.""" - - @autospec("bot.exts.filters.token_remover", "TokenRemover") - async def test_extension_setup(self, cog): - """The TokenRemover cog should be added.""" - bot = MockBot() - await token_remover.setup(bot) - - cog.assert_called_once_with(bot) - bot.add_cog.assert_awaited_once() - self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover)) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py deleted file mode 100644 index 0d570f5a3..000000000 --- a/tests/bot/rules/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest -from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple - -from tests.helpers import MockMessage - - -class DisallowedCase(NamedTuple): - """Encapsulation for test cases expected to fail.""" - recent_messages: List[MockMessage] - culprits: Iterable[str] - n_violations: int - - -class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): - """ - Abstract class for antispam rule test cases. - - Tests for specific rules should inherit from `RuleTest` and implement - `relevant_messages` and `get_report`. Each instance should also set the - `apply` and `config` attributes as necessary. - - The execution of test cases can then be delegated to the `run_allowed` - and `run_disallowed` methods. - """ - - apply: Callable # The tested rule's apply function - config: Dict[str, int] - - async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: - """Run all `cases` against `self.apply` expecting them to pass.""" - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config, - ): - self.assertIsNone( - await self.apply(last_message, recent_messages, self.config) - ) - - async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: - """Run all `cases` against `self.apply` expecting them to fail.""" - for case in cases: - recent_messages, culprits, n_violations = case - last_message = recent_messages[0] - relevant_messages = self.relevant_messages(case) - desired_output = ( - self.get_report(case), - culprits, - relevant_messages, - ) - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - n_violations=n_violations, - config=self.config, - ): - self.assertTupleEqual( - await self.apply(last_message, recent_messages, self.config), - desired_output, - ) - - @abstractmethod - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - """Give expected relevant messages for `case`.""" - raise NotImplementedError # pragma: no cover - - @abstractmethod - def get_report(self, case: DisallowedCase) -> str: - """Give expected error report for `case`.""" - raise NotImplementedError # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py deleted file mode 100644 index d7e779221..000000000 --- a/tests/bot/rules/test_attachments.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Iterable - -from bot.rules import attachments -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_attachments: int) -> MockMessage: - """Builds a message with `total_attachments` attachments.""" - return MockMessage(author=author, attachments=list(range(total_attachments))) - - -class AttachmentRuleTests(RuleTest): - """Tests applying the `attachments` antispam rule.""" - - def setUp(self): - self.apply = attachments.apply - self.config = {"max": 5, "interval": 10} - - async def test_allows_messages_without_too_many_attachments(self): - """Messages without too many attachments are allowed as-is.""" - cases = ( - [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], - [make_msg("bob", 2), make_msg("bob", 2)], - [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_with_too_many_attachments(self): - """Messages with too many attachments trigger the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], - ("bob",), - 10, - ), - DisallowedCase( - [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], - ("bob",), - 6, - ), - DisallowedCase( - [make_msg("alice", 6)], - ("alice",), - 6, - ), - DisallowedCase( - [make_msg("alice", 1) for _ in range(6)], - ("alice",), - 6, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if ( - msg.author == last_message.author - and len(msg.attachments) > 0 - ) - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py deleted file mode 100644 index 03682966b..000000000 --- a/tests/bot/rules/test_burst.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Iterable - -from bot.rules import burst -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: - """ - Init a MockMessage instance with author set to `author`. - - This serves as a shorthand / alias to keep the test cases visually clean. - """ - return MockMessage(author=author) - - -class BurstRuleTests(RuleTest): - """Tests the `burst` antispam rule.""" - - def setUp(self): - self.apply = burst.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases which do not violate the rule.""" - cases = ( - [make_msg("bob"), make_msg("bob")], - [make_msg("bob"), make_msg("alice"), make_msg("bob")], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases where the amount of messages exceeds the limit, triggering the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("bob")], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], - ("bob",), - 3, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py deleted file mode 100644 index 3275143d5..000000000 --- a/tests/bot/rules/test_burst_shared.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Iterable - -from bot.rules import burst_shared -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str) -> MockMessage: - """ - Init a MockMessage instance with the passed arg. - - This serves as a shorthand / alias to keep the test cases visually clean. - """ - return MockMessage(author=author) - - -class BurstSharedRuleTests(RuleTest): - """Tests the `burst_shared` antispam rule.""" - - def setUp(self): - self.apply = burst_shared.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """ - Cases that do not violate the rule. - - There really isn't more to test here than a single case. - """ - cases = ( - [make_msg("spongebob"), make_msg("patrick")], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases where the amount of messages exceeds the limit, triggering the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("bob")], - {"bob"}, - 3, - ), - DisallowedCase( - [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], - {"bob", "alice"}, - 4, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - return case.recent_messages - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py deleted file mode 100644 index f1e3c76a7..000000000 --- a/tests/bot/rules/test_chars.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import chars -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_chars: int) -> MockMessage: - """Build a message with arbitrary content of `n_chars` length.""" - return MockMessage(author=author, content="A" * n_chars) - - -class CharsRuleTests(RuleTest): - """Tests the `chars` antispam rule.""" - - def setUp(self): - self.apply = chars.apply - self.config = { - "max": 20, # Max allowed sum of chars per user - "interval": 10, - } - - async def test_allows_messages_within_limit(self): - """Cases with a total amount of chars within limit.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 20)], - [make_msg("bob", 15), make_msg("alice", 15)], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases where the total amount of chars exceeds the limit, triggering the rule.""" - cases = ( - DisallowedCase( - [make_msg("bob", 21)], - ("bob",), - 21, - ), - DisallowedCase( - [make_msg("bob", 15), make_msg("bob", 15)], - ("bob",), - 30, - ), - DisallowedCase( - [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], - ("alice",), - 30, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py deleted file mode 100644 index 66c2d9f92..000000000 --- a/tests/bot/rules/test_discord_emojis.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Iterable - -from bot.rules import discord_emojis -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - -discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> -unicode_emoji = "🧪" - - -def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage: - """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" - return MockMessage(author=author, content=emoji * n_emojis) - - -class DiscordEmojisRuleTests(RuleTest): - """Tests for the `discord_emojis` antispam rule.""" - - def setUp(self): - self.apply = discord_emojis.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases with a total amount of discord and unicode emojis within limit.""" - cases = ( - [make_msg("bob", 2)], - [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], - [make_msg("bob", 2, unicode_emoji)], - [ - make_msg("alice", 1, unicode_emoji), - make_msg("bob", 2, unicode_emoji), - make_msg("alice", 1, unicode_emoji) - ], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases with more than the allowed amount of discord and unicode emojis.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], - ("alice",), - 4, - ), - DisallowedCase( - [make_msg("bob", 3, unicode_emoji)], - ("bob",), - 3, - ), - DisallowedCase( - [ - make_msg("alice", 2, unicode_emoji), - make_msg("bob", 2, unicode_emoji), - make_msg("alice", 2, unicode_emoji) - ], - ("alice",), - 4 - ) - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py deleted file mode 100644 index 9bd886a77..000000000 --- a/tests/bot/rules/test_duplicates.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Iterable - -from bot.rules import duplicates -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, content: str) -> MockMessage: - """Give a MockMessage instance with `author` and `content` attrs.""" - return MockMessage(author=author, content=content) - - -class DuplicatesRuleTests(RuleTest): - """Tests the `duplicates` antispam rule.""" - - def setUp(self): - self.apply = duplicates.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases which do not violate the rule.""" - cases = ( - [make_msg("alice", "A"), make_msg("alice", "A")], - [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate - [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases with too many duplicate messages from the same author.""" - cases = ( - DisallowedCase( - [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], - ("alice",), - 3, - ), - DisallowedCase( - [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], - ("bob",), - 3, # 4 duplicate messages, but only 3 from bob - ), - DisallowedCase( - [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], - ("bob",), - 3, # 4 message from bob, but only 3 duplicates - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if ( - msg.author == last_message.author - and msg.content == last_message.content - ) - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py deleted file mode 100644 index b091bd9d7..000000000 --- a/tests/bot/rules/test_links.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Iterable - -from bot.rules import links -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, total_links: int) -> MockMessage: - """Makes a message with `total_links` links.""" - content = " ".join(["https://pydis.com"] * total_links) - return MockMessage(author=author, content=content) - - -class LinksTests(RuleTest): - """Tests applying the `links` rule.""" - - def setUp(self): - self.apply = links.apply - self.config = { - "max": 2, - "interval": 10 - } - - async def test_links_within_limit(self): - """Messages with an allowed amount of links.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 2)], - [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 - [make_msg("bob", 1), make_msg("bob", 1)], - [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count - ) - - await self.run_allowed(cases) - - async def test_links_exceeding_limit(self): - """Messages with a a higher than allowed amount of links.""" - cases = ( - DisallowedCase( - [make_msg("bob", 1), make_msg("bob", 2)], - ("bob",), - 3 - ), - DisallowedCase( - [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], - ("alice",), - 3 - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], - ("alice",), - 3 - ) - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py deleted file mode 100644 index f8805ac48..000000000 --- a/tests/bot/rules/test_mentions.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Iterable - -from bot.rules import mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage - - -def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage: - """Makes a message with `total_mentions` mentions.""" - user_mentions = [MockMember() for _ in range(total_user_mentions)] - bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - return MockMessage(author=author, mentions=user_mentions+bot_mentions) - - -class TestMentions(RuleTest): - """Tests applying the `mentions` antispam rule.""" - - def setUp(self): - self.apply = mentions.apply - self.config = { - "max": 2, - "interval": 10, - } - - async def test_mentions_within_limit(self): - """Messages with an allowed amount of mentions.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 2)], - [make_msg("bob", 1), make_msg("bob", 1)], - [make_msg("bob", 1), make_msg("alice", 2)], - ) - - await self.run_allowed(cases) - - async def test_mentions_exceeding_limit(self): - """Messages with a higher than allowed amount of mentions.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], - ("alice",), - 3, - ), - DisallowedCase( - [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], - ("bob",), - 4, - ), - DisallowedCase( - [make_msg("bob", 3, 1)], - ("bob",), - 3, - ), - ) - - await self.run_disallowed(cases) - - async def test_ignore_bot_mentions(self): - """Messages with an allowed amount of mentions, also containing bot mentions.""" - cases = ( - [make_msg("bob", 0, 3)], - [make_msg("bob", 2, 1)], - [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], - [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] - ) - - await self.run_allowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py deleted file mode 100644 index e35377773..000000000 --- a/tests/bot/rules/test_newlines.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Iterable, List - -from bot.rules import newlines -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, newline_groups: List[int]) -> MockMessage: - """Init a MockMessage instance with `author` and content configured by `newline_groups". - - Configure content by passing a list of ints, where each int `n` will generate - a separate group of `n` newlines. - - Example: - newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" - """ - content = " ".join("\n" * n for n in newline_groups) - return MockMessage(author=author, content=content) - - -class TotalNewlinesRuleTests(RuleTest): - """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" - - def setUp(self): - self.apply = newlines.apply - self.config = { - "max": 5, # Max sum of newlines in relevant messages - "max_consecutive": 3, # Max newlines in one group, in one message - "interval": 10, - } - - async def test_allows_messages_within_limit(self): - """Cases which do not violate the rule.""" - cases = ( - [make_msg("alice", [])], # Single message with no newlines - [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages - [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author - [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_total(self): - """Cases which violate the rule by having too many newlines in total.""" - cases = ( - DisallowedCase( # Alice sends a total of 6 newlines (disallowed) - [make_msg("alice", [2, 2]), make_msg("alice", [2])], - ("alice",), - 6, - ), - DisallowedCase( # Here we test that only alice's newlines count in the sum - [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], - ("alice",), - 7, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_author = case.recent_messages[0].author - return tuple(msg for msg in case.recent_messages if msg.author == last_author) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} newlines in {self.config['interval']}s" - - -class GroupNewlinesRuleTests(RuleTest): - """ - Tests the `newlines` antispam rule against max consecutive newline violations. - - As these violations yield a different error report, they require a different - `get_report` implementation. - """ - - def setUp(self): - self.apply = newlines.apply - self.config = {"max": 5, "max_consecutive": 3, "interval": 10} - - async def test_disallows_messages_consecutive(self): - """Cases which violate the rule due to having too many consecutive newlines.""" - cases = ( - DisallowedCase( # Bob sends a group of newlines too large - [make_msg("bob", [4])], - ("bob",), - 4, - ), - DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed) - [make_msg("alice", [1]), make_msg("alice", [4])], - ("alice",), - 4, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_author = case.recent_messages[0].author - return tuple(msg for msg in case.recent_messages if msg.author == last_author) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py deleted file mode 100644 index 26c05d527..000000000 --- a/tests/bot/rules/test_role_mentions.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Iterable - -from bot.rules import role_mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage - - -def make_msg(author: str, n_mentions: int) -> MockMessage: - """Build a MockMessage instance with `n_mentions` role mentions.""" - return MockMessage(author=author, role_mentions=[None] * n_mentions) - - -class RoleMentionsRuleTests(RuleTest): - """Tests for the `role_mentions` antispam rule.""" - - def setUp(self): - self.apply = role_mentions.apply - self.config = {"max": 2, "interval": 10} - - async def test_allows_messages_within_limit(self): - """Cases with a total amount of role mentions within limit.""" - cases = ( - [make_msg("bob", 2)], - [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], - ) - - await self.run_allowed(cases) - - async def test_disallows_messages_beyond_limit(self): - """Cases with more than the allowed amount of role mentions.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], - ("alice",), - 4, - ), - ) - - await self.run_disallowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} role mentions in {self.config['interval']}s" -- cgit v1.2.3 From 8095800ae8f38928ab8c406e622ec79ea93b21c3 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Thu, 9 Dec 2021 00:30:41 +0200 Subject: New filtering backbone and regex filtering migration This commit provides the basis of the new filtering system: - The filtering cog consists of several filter lists loaded from the database (filtering.py). - Each filter list contains a list of filters, which are run in response to events (message posting, reaction, thread creation). Each filter list may choose to respond to different events (the subscribe method in filtering.py). - Each filter has settings (settings.py) which decide when it is going to be run (e.g it might be disabled in a specific channel), and what will happen if it triggers (e.g delete the offending message). - Not every filter has a value for every setting (the _settings_types package) . It will use the default settings specified by its filter list as a fallback. - Since each filter might have a different effect when triggered, we must check all relevant filters even if we found a triggered filter already, unlike in the old system. - Two triggered filters may specify different values for the same setting, therefore each entry has a rule for combining two different values (the __or__ method in each file in _settings_types). To avoid having to prefix each file with an underscore (or the bot will try to load it as a cog), the loading script was changed to ignore packages with names starting with an underscore. Alert sending is done via a webhook so that several embeds can be sent in the same message (will be useful for example for guild invite alerts). Filter lists and setting entries classes are loaded dynamically from their respective packages. In order to be able to test the new features, this commit also includes a migration of the regex-based filtering. --- bot/constants.py | 1 + bot/exts/filtering/README.md | 0 bot/exts/filtering/__init__.py | 0 bot/exts/filtering/_filter_context.py | 39 +++ bot/exts/filtering/_filter_lists/__init__.py | 9 + bot/exts/filtering/_filter_lists/filter_list.py | 79 ++++++ bot/exts/filtering/_filter_lists/token.py | 45 ++++ bot/exts/filtering/_filters/__init__.py | 0 bot/exts/filtering/_filters/filter.py | 29 +++ bot/exts/filtering/_filters/token.py | 20 ++ bot/exts/filtering/_settings.py | 180 ++++++++++++++ bot/exts/filtering/_settings_types/__init__.py | 14 ++ bot/exts/filtering/_settings_types/bypass_roles.py | 29 +++ .../filtering/_settings_types/channel_scope.py | 45 ++++ .../filtering/_settings_types/delete_messages.py | 35 +++ bot/exts/filtering/_settings_types/enabled.py | 18 ++ bot/exts/filtering/_settings_types/filter_dm.py | 18 ++ .../_settings_types/infraction_and_notification.py | 180 ++++++++++++++ bot/exts/filtering/_settings_types/ping.py | 52 ++++ bot/exts/filtering/_settings_types/send_alert.py | 26 ++ .../filtering/_settings_types/settings_entry.py | 85 +++++++ bot/exts/filtering/_utils.py | 97 ++++++++ bot/exts/filtering/filtering.py | 150 ++++++++++++ bot/utils/messages.py | 9 + config-default.yml | 1 + tests/bot/exts/filtering/__init__.py | 0 tests/bot/exts/filtering/test_filters.py | 41 ++++ tests/bot/exts/filtering/test_settings.py | 20 ++ tests/bot/exts/filtering/test_settings_entries.py | 272 +++++++++++++++++++++ 29 files changed, 1494 insertions(+) create mode 100644 bot/exts/filtering/README.md create mode 100644 bot/exts/filtering/__init__.py create mode 100644 bot/exts/filtering/_filter_context.py create mode 100644 bot/exts/filtering/_filter_lists/__init__.py create mode 100644 bot/exts/filtering/_filter_lists/filter_list.py create mode 100644 bot/exts/filtering/_filter_lists/token.py create mode 100644 bot/exts/filtering/_filters/__init__.py create mode 100644 bot/exts/filtering/_filters/filter.py create mode 100644 bot/exts/filtering/_filters/token.py create mode 100644 bot/exts/filtering/_settings.py create mode 100644 bot/exts/filtering/_settings_types/__init__.py create mode 100644 bot/exts/filtering/_settings_types/bypass_roles.py create mode 100644 bot/exts/filtering/_settings_types/channel_scope.py create mode 100644 bot/exts/filtering/_settings_types/delete_messages.py create mode 100644 bot/exts/filtering/_settings_types/enabled.py create mode 100644 bot/exts/filtering/_settings_types/filter_dm.py create mode 100644 bot/exts/filtering/_settings_types/infraction_and_notification.py create mode 100644 bot/exts/filtering/_settings_types/ping.py create mode 100644 bot/exts/filtering/_settings_types/send_alert.py create mode 100644 bot/exts/filtering/_settings_types/settings_entry.py create mode 100644 bot/exts/filtering/_utils.py create mode 100644 bot/exts/filtering/filtering.py create mode 100644 tests/bot/exts/filtering/__init__.py create mode 100644 tests/bot/exts/filtering/test_filters.py create mode 100644 tests/bot/exts/filtering/test_settings.py create mode 100644 tests/bot/exts/filtering/test_settings_entries.py (limited to 'tests') diff --git a/bot/constants.py b/bot/constants.py index c39f9d2b8..65791daa3 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -477,6 +477,7 @@ class Webhooks(metaclass=YAMLGetter): duck_pond: int incidents: int incidents_archive: int + filters: int class Roles(metaclass=YAMLGetter): diff --git a/bot/exts/filtering/README.md b/bot/exts/filtering/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/filtering/__init__.py b/bot/exts/filtering/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py new file mode 100644 index 000000000..ee9e87f56 --- /dev/null +++ b/bot/exts/filtering/_filter_context.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from enum import Enum, auto +from typing import Optional, Union + +from discord import DMChannel, Embed, Message, TextChannel, Thread, User + + +class Event(Enum): + """Types of events that can trigger filtering. Note this does not have to align with gateway event types.""" + + MESSAGE = auto() + MESSAGE_EDIT = auto() + + +@dataclass +class FilterContext: + """A dataclass containing the information that should be filtered, and output information of the filtering.""" + + # Input context + event: Event # The type of event + author: User # Who triggered the event + channel: Union[TextChannel, Thread, DMChannel] # The channel involved + content: str # What actually needs filtering + message: Optional[Message] # The message involved + embeds: list = field(default_factory=list) # Any embeds involved + # Output context + dm_content: str = field(default_factory=str) # The content to DM the invoker + dm_embed: Embed = field(default_factory=Embed) # The embed to DM the invoker + send_alert: bool = field(default=True) # Whether to send an alert for the moderators + alert_content: str = field(default_factory=str) # The content of the alert + alert_embeds: list = field(default_factory=list) # Any embeds to add to the alert + action_descriptions: list = field(default_factory=list) # What actions were taken + matches: list = field(default_factory=list) # What exactly was found + + def replace(self, **changes) -> FilterContext: + """Return a new context object assigning new values to the specified fields.""" + return replace(self, **changes) diff --git a/bot/exts/filtering/_filter_lists/__init__.py b/bot/exts/filtering/_filter_lists/__init__.py new file mode 100644 index 000000000..415e3a6bf --- /dev/null +++ b/bot/exts/filtering/_filter_lists/__init__.py @@ -0,0 +1,9 @@ +from os.path import dirname + +from bot.exts.filtering._filter_lists.filter_list import FilterList +from bot.exts.filtering._utils import subclasses_in_package + +filter_list_types = subclasses_in_package(dirname(__file__), f"{__name__}.", FilterList) +filter_list_types = {filter_list.name: filter_list for filter_list in filter_list_types} + +__all__ = [filter_list_types, FilterList] diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py new file mode 100644 index 000000000..f9e304b59 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -0,0 +1,79 @@ +from abc import abstractmethod +from enum import Enum +from typing import Dict, List, Type + +from bot.exts.filtering._settings import Settings, ValidationSettings, create_settings +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._utils import FieldRequiring +from bot.log import get_logger + +log = get_logger(__name__) + + +class ListType(Enum): + DENY = 0 + ALLOW = 1 + + +class FilterList(FieldRequiring): + """Dispatches events to lists of _filters, and aggregates the responses into a single list of actions to take.""" + + # Each subclass must define a name matching the filter_list name we're expecting to receive from the database. + # Names must be unique across all filter lists. + name = FieldRequiring.MUST_SET_UNIQUE + + def __init__(self, filter_type: Type[Filter]): + self._filter_lists: dict[ListType, list[Filter]] = {} + self._defaults: dict[ListType, dict[str, Settings]] = {} + + self.filter_type = filter_type + + def add_list(self, list_data: Dict) -> None: + """Add a new type of list (such as a whitelist or a blacklist) this filter list.""" + actions, validations = create_settings(list_data["settings"]) + list_type = ListType(list_data["list_type"]) + self._defaults[list_type] = {"actions": actions, "validations": validations} + + filters = [] + for filter_data in list_data["filters"]: + try: + filters.append(self.filter_type(filter_data, actions)) + except TypeError as e: + log.warning(e) + self._filter_lists[list_type] = filters + + @abstractmethod + def triggers_for(self, ctx: FilterContext) -> list[Filter]: + """Dispatch the given event to the list's filters, and return filters triggered.""" + + @staticmethod + def filter_list_result(ctx: FilterContext, filters: List[Filter], defaults: ValidationSettings) -> list[Filter]: + """ + Sift through the list of filters, and return only the ones which apply to the given context. + + The strategy is as follows: + 1. The default settings are evaluated on the given context. The default answer for whether the filter is + relevant in the given context is whether there aren't any validation settings which returned False. + 2. For each filter, its overrides are considered: + - If there are no overrides, then the filter is relevant if that is the default answer. + - Otherwise it is relevant if there are no failed overrides, and any failing default is overridden by a + successful override. + + If the filter is relevant in context, see if it actually triggers. + """ + passed_by_default, failed_by_default = defaults.evaluate(ctx) + default_answer = not bool(failed_by_default) + + relevant_filters = [] + for filter_ in filters: + if not filter_.validations: + if default_answer and filter_.triggered_on(ctx): + relevant_filters.append(filter_) + else: + passed, failed = filter_.validations.evaluate(ctx) + if not failed and failed_by_default < passed: + if filter_.triggered_on(ctx): + relevant_filters.append(filter_) + + return relevant_filters diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py new file mode 100644 index 000000000..4495f4414 --- /dev/null +++ b/bot/exts/filtering/_filter_lists/token.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import re +import typing + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filters.token import TokenFilter +from bot.exts.filtering._utils import clean_input + +if typing.TYPE_CHECKING: + from bot.exts.filtering.filtering import Filtering + +SPOILER_RE = re.compile(r"(\|\|.+?\|\|)", re.DOTALL) + + +class TokensList(FilterList): + """A list of filters, each looking for a specific token given by regex.""" + + name = "token" + + def __init__(self, filtering_cog: Filtering): + super().__init__(TokenFilter) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT) + + def triggers_for(self, ctx: FilterContext) -> list[Filter]: + """Dispatch the given event to the list's filters, and return filters triggered.""" + text = ctx.content + if SPOILER_RE.search(text): + text = self._expand_spoilers(text) + text = clean_input(text) + ctx = ctx.replace(content=text) + + return self.filter_list_result( + ctx, self._filter_lists[ListType.DENY], self._defaults[ListType.DENY]["validations"] + ) + + @staticmethod + def _expand_spoilers(text: str) -> str: + """Return a string containing all interpretations of a spoilered message.""" + split_text = SPOILER_RE.split(text) + return ''.join( + split_text[0::2] + split_text[1::2] + split_text + ) diff --git a/bot/exts/filtering/_filters/__init__.py b/bot/exts/filtering/_filters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py new file mode 100644 index 000000000..484e506fc --- /dev/null +++ b/bot/exts/filtering/_filters/filter.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings import ActionSettings, create_settings + + +class Filter(ABC): + """ + A class representing a filter. + + Each filter looks for a specific attribute within an event (such as message sent), + and defines what action should be performed if it is triggered. + """ + + def __init__(self, filter_data: Dict, action_defaults: Optional[ActionSettings] = None): + self.id = filter_data["id"] + self.content = filter_data["content"] + self.description = filter_data["description"] + self.actions, self.validations = create_settings(filter_data["settings"]) + if not self.actions: + self.actions = action_defaults + elif action_defaults: + self.actions.fallback_to(action_defaults) + self.exact = filter_data["additional_field"] + + @abstractmethod + def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" diff --git a/bot/exts/filtering/_filters/token.py b/bot/exts/filtering/_filters/token.py new file mode 100644 index 000000000..07590c54b --- /dev/null +++ b/bot/exts/filtering/_filters/token.py @@ -0,0 +1,20 @@ +import re + +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._filter_context import FilterContext + + +class TokenFilter(Filter): + """A filter which looks for a specific token given by regex.""" + + def triggered_on(self, ctx: FilterContext) -> bool: + """Searches for a regex pattern within a given context.""" + pattern = self.content + + match = re.search(pattern, ctx.content, flags=re.IGNORECASE) + if match: + ctx.matches.append(match[0]) + return True + return False + + diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py new file mode 100644 index 000000000..96e1c1f7f --- /dev/null +++ b/bot/exts/filtering/_settings.py @@ -0,0 +1,180 @@ +from __future__ import annotations +from abc import abstractmethod +from typing import Iterator, Mapping, Optional + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types import settings_types +from bot.exts.filtering._settings_types.settings_entry import ActionEntry, ValidationEntry +from bot.exts.filtering._utils import FieldRequiring +from bot.log import get_logger + +log = get_logger(__name__) + +_already_warned: set[str] = set() + + +def create_settings(settings_data: dict) -> tuple[Optional[ActionSettings], Optional[ValidationSettings]]: + """ + Create and return instances of the Settings subclasses from the given data + + Additionally, warn for data entries with no matching class. + """ + action_data = {} + validation_data = {} + for entry_name, entry_data in settings_data.items(): + if entry_name in settings_types["ActionEntry"]: + action_data[entry_name] = entry_data + elif entry_name in settings_types["ValidationEntry"]: + validation_data[entry_name] = entry_data + else: + log.warning( + f"A setting named {entry_name} was loaded from the database, but no matching class." + ) + _already_warned.add(entry_name) + return ActionSettings.create(action_data), ValidationSettings.create(validation_data) + + +class Settings(FieldRequiring): + """ + A collection of settings. + + For processing the settings parts in the database and evaluating them on given contexts. + + Each filter list and filter has its own settings. + + A filter doesn't have to have its own settings. For every undefined setting, it falls back to the value defined in + the filter list which contains the filter. + """ + + entry_type = FieldRequiring.MUST_SET + + _already_warned: set[str] = set() + + @abstractmethod + def __init__(self, settings_data: dict): + self._entries: dict[str, Settings.entry_type] = {} + + entry_classes = settings_types.get(self.entry_type.__name__) + for entry_name, entry_data in settings_data.items(): + try: + entry_cls = entry_classes[entry_name] + except KeyError: + if entry_name not in self._already_warned: + log.warning( + f"A setting named {entry_name} was loaded from the database, " + f"but no matching {self.entry_type.__name__} class." + ) + self._already_warned.add(entry_name) + else: + try: + new_entry = entry_cls.create(entry_data) + if new_entry: + self._entries[entry_name] = new_entry + except TypeError as e: + raise TypeError( + f"Attempted to load a {entry_name} setting, but the response is malformed: {entry_data}" + ) from e + + def __contains__(self, item) -> bool: + return item in self._entries + + def __setitem__(self, key: str, value: entry_type) -> None: + self._entries[key] = value + + def copy(self): + copy = self.__class__({}) + copy._entries = self._entries + return copy + + def items(self) -> Iterator[tuple[str, entry_type]]: + yield from self._entries.items() + + def update(self, mapping: Mapping[str, entry_type], **kwargs: entry_type) -> None: + self._entries.update(mapping, **kwargs) + + @classmethod + def create(cls, settings_data: dict) -> Optional[Settings]: + """ + Returns a Settings object from `settings_data` if it holds any value, None otherwise. + + Use this method to create Settings objects instead of the init. + The None value is significant for how a filter list iterates over its filters. + """ + settings = cls(settings_data) + # If an entry doesn't hold any values, its `create` method will return None. + # If all entries are None, then the settings object holds no values. + if not any(settings._entries.values()): + return None + + return settings + + +class ValidationSettings(Settings): + """ + A collection of validation settings. + + A filter is triggered only if all of its validation settings (e.g whether to invoke in DM) approve + (the check returns True). + """ + + entry_type = ValidationEntry + + def __init__(self, settings_data: dict): + super().__init__(settings_data) + + def evaluate(self, ctx: FilterContext) -> tuple[set[str], set[str]]: + """Evaluates for each setting whether the context is relevant to the filter.""" + passed = set() + failed = set() + + self._entries: dict[str, ValidationEntry] + for name, validation in self._entries.items(): + if validation: + if validation.triggers_on(ctx): + passed.add(name) + else: + failed.add(name) + + return passed, failed + + +class ActionSettings(Settings): + """ + A collection of action settings. + + If a filter is triggered, its action settings (e.g how to infract the user) are combined with the action settings of + other triggered filters in the same event, and action is taken according to the combined action settings. + """ + + entry_type = ActionEntry + + def __init__(self, settings_data: dict): + super().__init__(settings_data) + + def __or__(self, other: ActionSettings) -> ActionSettings: + """Combine the entries of two collections of settings into a new ActionsSettings""" + actions = {} + # A settings object doesn't necessarily have all types of entries (e.g in the case of filter overrides). + for entry in self._entries: + if entry in other._entries: + actions[entry] = self._entries[entry] | other._entries[entry] + else: + actions[entry] = self._entries[entry] + for entry in other._entries: + if entry not in actions: + actions[entry] = other._entries[entry] + + result = ActionSettings({}) + result.update(actions) + return result + + async def action(self, ctx: FilterContext) -> None: + """Execute the action of every action entry stored.""" + for entry in self._entries.values(): + await entry.action(ctx) + + def fallback_to(self, fallback: ActionSettings) -> None: + """Fill in missing entries from `fallback`.""" + for entry_name, entry_value in fallback.items(): + if entry_name not in self._entries: + self._entries[entry_name] = entry_value diff --git a/bot/exts/filtering/_settings_types/__init__.py b/bot/exts/filtering/_settings_types/__init__.py new file mode 100644 index 000000000..620290cb2 --- /dev/null +++ b/bot/exts/filtering/_settings_types/__init__.py @@ -0,0 +1,14 @@ +from os.path import dirname + +from bot.exts.filtering._settings_types.settings_entry import ActionEntry, ValidationEntry +from bot.exts.filtering._utils import subclasses_in_package + +action_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ActionEntry) +validation_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ValidationEntry) + +settings_types = { + "ActionEntry": {settings_type.name: settings_type for settings_type in action_types}, + "ValidationEntry": {settings_type.name: settings_type for settings_type in validation_types} +} + +__all__ = [settings_types] diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py new file mode 100644 index 000000000..9665283ff --- /dev/null +++ b/bot/exts/filtering/_settings_types/bypass_roles.py @@ -0,0 +1,29 @@ +from typing import Any + +from discord import Member + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry +from bot.exts.filtering._utils import ROLE_LITERALS + + +class RoleBypass(ValidationEntry): + """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" + + name = "bypass_roles" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.roles = set() + for role in entry_data: + if role in ROLE_LITERALS: + self.roles.add(ROLE_LITERALS[role]) + elif role.isdigit(): + self.roles.add(int(role)) + # Ignore entries that can't be resolved. + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered on this user given their roles.""" + if not isinstance(ctx.author, Member): + return True + return all(member_role.id not in self.roles for member_role in ctx.author.roles) diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py new file mode 100644 index 000000000..b17914f2f --- /dev/null +++ b/bot/exts/filtering/_settings_types/channel_scope.py @@ -0,0 +1,45 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class ChannelScope(ValidationEntry): + """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" + + name = "channel_scope" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + if entry_data["disabled_channels"]: + self.disabled_channels = set(entry_data["disabled_channels"]) + else: + self.disabled_channels = set() + + if entry_data["disabled_categories"]: + self.disabled_categories = set(entry_data["disabled_categories"]) + else: + self.disabled_categories = set() + + if entry_data["enabled_channels"]: + self.enabled_channels = set(entry_data["enabled_channels"]) + else: + self.enabled_channels = set() + + def triggers_on(self, ctx: FilterContext) -> bool: + """ + Return whether the filter should be triggered in the given channel. + + The filter is invoked by default. + If the channel is explicitly enabled, it bypasses the set disabled channels and categories. + """ + channel = ctx.channel + if hasattr(channel, "parent"): + channel = channel.parent + return ( + channel.id in self.enabled_channels + or ( + channel.id not in self.disabled_channels + and (not channel.category or channel.category.id not in self.disabled_categories) + ) + ) diff --git a/bot/exts/filtering/_settings_types/delete_messages.py b/bot/exts/filtering/_settings_types/delete_messages.py new file mode 100644 index 000000000..b0a018433 --- /dev/null +++ b/bot/exts/filtering/_settings_types/delete_messages.py @@ -0,0 +1,35 @@ +from contextlib import suppress +from typing import Any + +from discord.errors import NotFound + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class DeleteMessages(ActionEntry): + """A setting entry which tells whether to delete the offending message(s).""" + + name = "delete_messages" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.delete: bool = entry_data + + async def action(self, ctx: FilterContext) -> None: + """Delete the context message(s).""" + if not self.delete or ctx.event not in (Event.MESSAGE, Event.MESSAGE_EDIT): + return + + with suppress(NotFound): + if ctx.message.guild: + await ctx.message.delete() + ctx.action_descriptions.append("deleted") + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, DeleteMessages): + return NotImplemented + + return DeleteMessages(self.delete or other.delete) + diff --git a/bot/exts/filtering/_settings_types/enabled.py b/bot/exts/filtering/_settings_types/enabled.py new file mode 100644 index 000000000..553dccc9c --- /dev/null +++ b/bot/exts/filtering/_settings_types/enabled.py @@ -0,0 +1,18 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class Enabled(ValidationEntry): + """A setting entry which tells whether the filter is enabled.""" + + name = "enabled" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.enabled = entry_data + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter is enabled.""" + return self.enabled diff --git a/bot/exts/filtering/_settings_types/filter_dm.py b/bot/exts/filtering/_settings_types/filter_dm.py new file mode 100644 index 000000000..54f19e4d1 --- /dev/null +++ b/bot/exts/filtering/_settings_types/filter_dm.py @@ -0,0 +1,18 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class FilterDM(ValidationEntry): + """A setting entry which tells whether to apply the filter to DMs.""" + + name = "filter_dm" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.apply_in_dm = entry_data + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered even if it was triggered in DMs.""" + return hasattr(ctx.channel, "guild") or self.apply_in_dm diff --git a/bot/exts/filtering/_settings_types/infraction_and_notification.py b/bot/exts/filtering/_settings_types/infraction_and_notification.py new file mode 100644 index 000000000..263fd851c --- /dev/null +++ b/bot/exts/filtering/_settings_types/infraction_and_notification.py @@ -0,0 +1,180 @@ +from collections import namedtuple +from datetime import timedelta +from enum import Enum, auto +from typing import Any, Optional + +import arrow +from discord import Colour +from discord.errors import Forbidden + +import bot +from bot.constants import Channels, Guild +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class Infraction(Enum): + """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" + + BAN = auto() + KICK = auto() + MUTE = auto() + VOICE_BAN = auto() + SUPERSTAR = auto() + WARNING = auto() + WATCH = auto() + NOTE = auto() + NONE = auto() # Allows making operations on an entry with no infraction without checking for None. + + def __bool__(self) -> bool: + """ + Make the NONE value false-y. + + This is useful for Settings.create to evaluate whether the entry contains anything. + """ + return self != Infraction.NONE + + +superstar = namedtuple("superstar", ["reason", "duration"]) + + +class InfractionAndNotification(ActionEntry): + """ + A setting entry which specifies what infraction to issue and the notification to DM the user. + + Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. + """ + + name = "infraction_and_notification" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + + if entry_data["infraction_type"]: + self.infraction_type = entry_data["infraction_type"] + if isinstance(self.infraction_type, str): + self.infraction_type = Infraction[self.infraction_type.replace(" ", "_").upper()] + self.infraction_reason = entry_data["infraction_reason"] + if entry_data["infraction_duration"] is not None: + self.infraction_duration = float(entry_data["infraction_duration"]) + else: + self.infraction_duration = None + else: + self.infraction_type = Infraction.NONE + self.infraction_reason = None + self.infraction_duration = 0 + + self.dm_content = entry_data["dm_content"] + self.dm_embed = entry_data["dm_embed"] + + self._superstar = entry_data.get("superstar", None) + + async def action(self, ctx: FilterContext) -> None: + """Send the notification to the user, and apply any specified infractions.""" + # If there is no infraction to apply, any DM contents already provided in the context take precedence. + if self.infraction_type == Infraction.NONE and (ctx.dm_content or ctx.dm_embed): + dm_content = ctx.dm_content + dm_embed = ctx.dm_embed.description + else: + dm_content = self.dm_content + dm_embed = self.dm_embed + + if dm_content or dm_embed: + dm_content = f"Hey {ctx.author.mention}!\n{dm_content}" + ctx.dm_embed.description = dm_embed + if not ctx.dm_embed.colour: + ctx.dm_embed.colour = Colour.og_blurple() + + try: + await ctx.author.send(dm_content, embed=ctx.dm_embed) + except Forbidden: + await ctx.channel.send(ctx.dm_content, embed=ctx.dm_embed) + ctx.action_descriptions.append("notified") + + msg_ctx = await bot.instance.get_context(ctx.message) + msg_ctx.guild = bot.instance.get_guild(Guild.id) + msg_ctx.author = ctx.author + msg_ctx.channel = ctx.channel + if self._superstar: + msg_ctx.command = bot.instance.get_command("superstarify") + await msg_ctx.invoke( + msg_ctx.command, + ctx.author, + arrow.utcnow() + timedelta(seconds=self._superstar.duration) + if self._superstar.duration is not None else None, + reason=self._superstar.reason + ) + ctx.action_descriptions.append("superstar") + + if self.infraction_type != Infraction.NONE: + if self.infraction_type == Infraction.BAN or not hasattr(ctx.channel, "guild"): + msg_ctx.channel = bot.instance.get_channel(Channels.mod_alerts) + msg_ctx.command = bot.instance.get_command(self.infraction_type.name) + await msg_ctx.invoke( + msg_ctx.command, + ctx.author, + arrow.utcnow() + timedelta(seconds=self.infraction_duration) + if self.infraction_duration is not None else None, + reason=self.infraction_reason + ) + ctx.action_descriptions.append(self.infraction_type.name.lower()) + + def __or__(self, other: ActionEntry): + """ + Combines two actions of the same type. Each type of action is executed once per filter. + + If the infractions are different, take the data of the one higher up the hierarchy. + + A special case is made for superstar infractions. Even if we decide to auto-mute a user, if they have a + particularly problematic username we will still want to superstarify them. + + This is a "best attempt" implementation. Trying to account for any type of combination would create an + extremely complex ruleset. For example, we could special-case watches as well. + + There is no clear way to properly combine several notification messages, especially when it's in two parts. + To avoid bombarding the user with several notifications, the message with the more significant infraction + is used. + """ + if not isinstance(other, InfractionAndNotification): + return NotImplemented + + # Lower number -> higher in the hierarchy + if self.infraction_type.value < other.infraction_type.value and other.infraction_type != Infraction.SUPERSTAR: + result = InfractionAndNotification(self.to_dict()) + result._superstar = self._merge_superstars(self._superstar, other._superstar) + return result + elif self.infraction_type.value > other.infraction_type.value and self.infraction_type != Infraction.SUPERSTAR: + result = InfractionAndNotification(other.to_dict()) + result._superstar = self._merge_superstars(self._superstar, other._superstar) + return result + + if self.infraction_type == other.infraction_type: + if self.infraction_duration is None or ( + other.infraction_duration is not None and self.infraction_duration > other.infraction_duration + ): + result = InfractionAndNotification(self.to_dict()) + else: + result = InfractionAndNotification(other.to_dict()) + result._superstar = self._merge_superstars(self._superstar, other._superstar) + return result + + # At this stage the infraction types are different, and the lower one is a superstar. + if self.infraction_type.value < other.infraction_type.value: + result = InfractionAndNotification(self.to_dict()) + result._superstar = superstar(other.infraction_reason, other.infraction_duration) + else: + result = InfractionAndNotification(other.to_dict()) + result._superstar = superstar(self.infraction_reason, self.infraction_duration) + return result + + @staticmethod + def _merge_superstars(superstar1: Optional[superstar], superstar2: Optional[superstar]) -> Optional[superstar]: + """Take the superstar with the greater duration.""" + if not superstar1: + return superstar2 + if not superstar2: + return superstar1 + + if superstar1.duration is None or superstar1.duration > superstar2.duration: + return superstar1 + return superstar2 diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py new file mode 100644 index 000000000..857e4a7e8 --- /dev/null +++ b/bot/exts/filtering/_settings_types/ping.py @@ -0,0 +1,52 @@ +from functools import cache +from typing import Any + +from discord import Guild + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import ROLE_LITERALS + + +class Ping(ActionEntry): + """A setting entry which adds the appropriate pings to the alert.""" + + name = "mentions" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.guild_mentions = set(entry_data["guild_pings"]) + self.dm_mentions = set(entry_data["dm_pings"]) + + async def action(self, ctx: FilterContext) -> None: + """Add the stored pings to the alert message content.""" + mentions = self.guild_mentions if ctx.channel.guild else self.dm_mentions + new_content = " ".join([self._resolve_mention(mention, ctx.channel.guild) for mention in mentions]) + ctx.alert_content = f"{new_content} {ctx.alert_content}" + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, Ping): + return NotImplemented + + return Ping({ + "ping_type": self.guild_mentions | other.guild_mentions, + "dm_ping_type": self.dm_mentions | other.dm_mentions + }) + + @staticmethod + @cache + def _resolve_mention(mention: str, guild: Guild) -> str: + """Return the appropriate formatting for the formatting, be it a literal, a user ID, or a role ID.""" + if mention in ("here", "everyone"): + return f"@{mention}" + if mention in ROLE_LITERALS: + return f"<@&{ROLE_LITERALS[mention]}>" + if not mention.isdigit(): + return mention + + mention = int(mention) + if any(mention == role.id for role in guild.roles): + return f"<@&{mention}>" + else: + return f"<@{mention}>" diff --git a/bot/exts/filtering/_settings_types/send_alert.py b/bot/exts/filtering/_settings_types/send_alert.py new file mode 100644 index 000000000..e332494eb --- /dev/null +++ b/bot/exts/filtering/_settings_types/send_alert.py @@ -0,0 +1,26 @@ +from typing import Any + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class SendAlert(ActionEntry): + """A setting entry which tells whether to send an alert message.""" + + name = "send_alert" + + def __init__(self, entry_data: Any): + super().__init__(entry_data) + self.send_alert: bool = entry_data + + async def action(self, ctx: FilterContext) -> None: + """Add the stored pings to the alert message content.""" + ctx.send_alert = self.send_alert + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, SendAlert): + return NotImplemented + + return SendAlert(self.send_alert or other.send_alert) + diff --git a/bot/exts/filtering/_settings_types/settings_entry.py b/bot/exts/filtering/_settings_types/settings_entry.py new file mode 100644 index 000000000..b0d54fac3 --- /dev/null +++ b/bot/exts/filtering/_settings_types/settings_entry.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Optional + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._utils import FieldRequiring + + +class SettingsEntry(FieldRequiring): + """ + A basic entry in the settings field appearing in every filter list and filter. + + For a filter list, this is the default setting for it. For a filter, it's an override of the default entry. + """ + + # Each subclass must define a name matching the entry name we're expecting to receive from the database. + # Names must be unique across all filter lists. + name = FieldRequiring.MUST_SET_UNIQUE + + @abstractmethod + def __init__(self, entry_data: Any): + super().__init__() + self._dict = {} + + def __setattr__(self, key: str, value: Any) -> None: + super().__setattr__(key, value) + if key == "_dict": + return + self._dict[key] = value + + def __eq__(self, other: SettingsEntry) -> bool: + if not isinstance(other, SettingsEntry): + return NotImplemented + return self._dict == other._dict + + def to_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the entry.""" + return self._dict.copy() + + def copy(self) -> SettingsEntry: + """Return a new entry object with the same parameters.""" + return self.__class__(self.to_dict()) + + @classmethod + def create(cls, entry_data: Optional[dict[str, Any]]) -> Optional[SettingsEntry]: + """ + Returns a SettingsEntry object from `entry_data` if it holds any value, None otherwise. + + Use this method to create SettingsEntry objects instead of the init. + The None value is significant for how a filter list iterates over its filters. + """ + if entry_data is None: + return None + if hasattr(entry_data, "values") and not any(value for value in entry_data.values()): + return None + + return cls(entry_data) + + +class ValidationEntry(SettingsEntry): + """A setting entry to validate whether the filter should be triggered in the given context.""" + + @abstractmethod + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered with this setting in the given context.""" + ... + + +class ActionEntry(SettingsEntry): + """A setting entry defining what the bot should do if the filter it belongs to is triggered.""" + + @abstractmethod + async def action(self, ctx: FilterContext) -> None: + """Execute an action that should be taken when the filter this setting belongs to is triggered.""" + ... + + @abstractmethod + def __or__(self, other: ActionEntry): + """ + Combine two actions of the same type. Each type of action is executed once per filter. + + The following condition must hold: if self == other, then self | other == self. + """ + ... diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py new file mode 100644 index 000000000..a769001f6 --- /dev/null +++ b/bot/exts/filtering/_utils.py @@ -0,0 +1,97 @@ +import importlib +import importlib.util +import inspect +import pkgutil +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Set + +import regex + +from bot.constants import Roles + +ROLE_LITERALS = { + "admins": Roles.admins, + "onduty": Roles.moderators, + "staff": Roles.helpers +} + +VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF" +INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1) +ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1) + + +def subclasses_in_package(package: str, prefix: str, parent: type) -> Set[type]: + """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path.""" + subclasses = set() + + # Find all modules in the package. + for module_info in pkgutil.iter_modules([package], prefix): + if not module_info.ispkg: + module = importlib.import_module(module_info.name) + # Find all classes in each module... + for _, class_ in inspect.getmembers(module, inspect.isclass): + # That are a subclass of the given class. + if parent in class_.__bases__: + subclasses.add(class_) + + return subclasses + + +def clean_input(string: str) -> str: + """Remove zalgo and invisible characters from `string`.""" + # For future consideration: remove characters in the Mc, Sk, and Lm categories too. + # Can be normalised with form C to merge char + combining char into a single char to avoid + # removing legit diacritics, but this would open up a way to bypass _filters. + no_zalgo = ZALGO_RE.sub("", string) + return INVISIBLE_RE.sub("", no_zalgo) + + +class FieldRequiring(ABC): + """A mixin class that can force its concrete subclasses to set a value for specific class attributes.""" + + # Sentinel value that mustn't remain in a concrete subclass. + MUST_SET = object() + + # Sentinel value that mustn't remain in a concrete subclass. + # Overriding value must be unique in the subclasses of the abstract class in which the attribute was set. + MUST_SET_UNIQUE = object() + + # A mapping of the attributes which must be unique, and their unique values, per FieldRequiring subclass. + __unique_attributes: defaultdict[type, dict[str, set]] = defaultdict(dict) + + @abstractmethod + def __init__(self): + ... + + def __init_subclass__(cls, **kwargs): + # If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it. + if inspect.isabstract(cls): + for attribute in dir(cls): + if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE: + for parent in cls.__mro__[1:-1]: # The first element is the class itself, last element is object. + if hasattr(parent, attribute): # The attribute was inherited. + break + else: + # A new attribute with the value MUST_SET_UNIQUE. + FieldRequiring.__unique_attributes[cls][attribute] = set() + return + + for attribute in dir(cls): + if attribute.startswith("__") or attribute in ("MUST_SET", "MUST_SET_UNIQUE"): + continue + value = getattr(cls, attribute) + if value is FieldRequiring.MUST_SET: + raise ValueError(f"You must set attribute {attribute!r} when creating {cls!r}") + elif value is FieldRequiring.MUST_SET_UNIQUE: + raise ValueError(f"You must set a unique value to attribute {attribute!r} when creating {cls!r}") + else: + # Check if the value needs to be unique. + for parent in cls.__mro__[1:-1]: + # Find the parent class the attribute was first defined in. + if attribute in FieldRequiring.__unique_attributes[parent]: + if value in FieldRequiring.__unique_attributes[parent][attribute]: + raise ValueError(f"Value of {attribute!r} in {cls!r} is not unique for parent {parent!r}.") + else: + # Add to the set of unique values for that field. + FieldRequiring.__unique_attributes[parent][attribute].add(value) diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py new file mode 100644 index 000000000..c74b85698 --- /dev/null +++ b/bot/exts/filtering/filtering.py @@ -0,0 +1,150 @@ +import operator +from collections import defaultdict +from functools import reduce +from typing import Optional + +from discord import Embed, HTTPException, Message +from discord.ext.commands import Cog +from discord.utils import escape_markdown + +from bot.bot import Bot +from bot.constants import Colours, Webhooks +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import FilterList, filter_list_types +from bot.exts.filtering._filters.filter import Filter +from bot.exts.filtering._settings import ActionSettings +from bot.log import get_logger +from bot.utils.messages import format_channel, format_user + +log = get_logger(__name__) + + +class Filtering(Cog): + """Filtering and alerting for content posted on the server.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.filter_lists: dict[str, FilterList] = {} + self._subscriptions: defaultdict[Event, list[FilterList]] = defaultdict(list) + self.webhook = None + + async def cog_load(self) -> None: + """ + Fetch the filter data from the API, parse it, and load it to the appropriate data structures. + + Additionally, fetch the alerting webhook. + """ + await self.bot.wait_until_guild_available() + already_warned = set() + + raw_filter_lists = await self.bot.api_client.get("bot/filter/filter_lists") + for raw_filter_list in raw_filter_lists: + list_name = raw_filter_list["name"] + if list_name not in self.filter_lists: + if list_name not in filter_list_types: + if list_name not in already_warned: + log.warning( + f"A filter list named {list_name} was loaded from the database, but no matching class." + ) + already_warned.add(list_name) + continue + self.filter_lists[list_name] = filter_list_types[list_name](self) + self.filter_lists[list_name].add_list(raw_filter_list) + + try: + self.webhook = await self.bot.fetch_webhook(Webhooks.filters) + except HTTPException: + log.error(f"Failed to fetch incidents webhook with id `{Webhooks.incidents}`.") + + def subscribe(self, filter_list: FilterList, *events: Event) -> None: + """ + Subscribe a filter list to the given events. + + The filter list is added to a list for each event. When the event is triggered, the filter context will be + dispatched to the subscribed filter lists. + + While it's possible to just make each filter list check the context's event, these are only the events a filter + list expects to receive from the filtering cog, there isn't an actual limitation on the kinds of events a filter + list can handle as long as the filter context is built properly. If for whatever reason we want to invoke a + filter list outside of the usual procedure with the filtering cog, it will be more problematic if the events are + hard-coded into each filter list. + """ + for event in events: + if filter_list not in self._subscriptions[event]: + self._subscriptions[event].append(filter_list) + + async def _resolve_action( + self, ctx: FilterContext + ) -> tuple[dict[FilterList, list[Filter]], Optional[ActionSettings]]: + """Get the filters triggered per list, and resolve from them the action that needs to be taken for the event.""" + triggered = {} + for filter_list in self._subscriptions[ctx.event]: + triggered[filter_list] = filter_list.triggers_for(ctx) + + result_actions = None + if triggered: + result_actions = reduce( + operator.or_, (filter_.actions for filters in triggered.values() for filter_ in filters) + ) + + return triggered, result_actions + + @Cog.listener() + async def on_message(self, msg: Message) -> None: + """Filter the contents of a sent message.""" + if msg.author.bot: + return + + ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds) + + triggered, result_actions = await self._resolve_action(ctx) + if result_actions: + await result_actions.action(ctx) + if ctx.send_alert: + await self._send_alert(ctx, triggered) + + async def _send_alert(self, ctx: FilterContext, triggered_filters: dict[FilterList, list[Filter]]) -> None: + """Build an alert message from the filter context, and send it via the alert webhook.""" + if not self.webhook: + return + + name = f"{ctx.event.name.replace('_', ' ').title()} Filter" + + embed = Embed(color=Colours.soft_orange) + embed.set_thumbnail(url=ctx.author.display_avatar.url) + triggered_by = f"**Triggered by:** {format_user(ctx.author)}" + if ctx.channel.guild: + triggered_in = f"**Triggered in:** {format_channel(ctx.channel)}" + else: + triggered_in = "**DM**" + if len(triggered_filters) == 1 and len(list(triggered_filters.values())[0]) == 1: + filter_list, (filter_,) = next(iter(triggered_filters.items())) + filters = f"**{filter_list.name.title()} Filter:** #{filter_.id} (`{filter_.content}`)" + if filter_.description: + filters += f" - {filter_.description}" + else: + filters = [] + for filter_list, list_filters in triggered_filters.items(): + filters.append( + (f"**{filter_list.name.title()} Filters:** " + ", ".join(f"#{filter_.id} (`{filter_.content}`)" for filter_ in list_filters)) + ) + filters = "\n".join(filters) + + matches = "**Matches:** " + ", ".join(repr(match) for match in ctx.matches) + actions = "**Actions Taken:** " + (", ".join(ctx.action_descriptions) if ctx.action_descriptions else "-") + content = f"**[Original Content]({ctx.message.jump_url})**: {escape_markdown(ctx.content)}" + + embed_content = "\n".join( + part for part in (triggered_by, triggered_in, filters, matches, actions, content) if part + ) + if len(embed_content) > 4000: + embed_content = embed_content[:4000] + " [...]" + embed.description = embed_content + + await self.webhook.send(username=name, content=ctx.alert_content, embeds=[embed, *ctx.alert_embeds]) + + +async def setup(bot: Bot) -> None: + """Load the Filtering cog.""" + await bot.add_cog(Filtering(bot)) diff --git a/bot/utils/messages.py b/bot/utils/messages.py index a5ed84351..63929cd0b 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -238,3 +238,12 @@ async def send_denial(ctx: Context, reason: str) -> discord.Message: def format_user(user: discord.abc.User) -> str: """Return a string for `user` which has their mention and ID.""" return f"{user.mention} (`{user.id}`)" + + +def format_channel(channel: discord.abc.Messageable) -> str: + """Return a string for `channel` with its mention, ID, and the parent channel if it is a thread.""" + formatted = f"{channel.mention} ({channel.category}/#{channel}" + if hasattr(channel, "parent"): + formatted += f"/{channel.parent}" + formatted += ")" + return formatted diff --git a/config-default.yml b/config-default.yml index 91945e2b8..1815b8ed7 100644 --- a/config-default.yml +++ b/config-default.yml @@ -317,6 +317,7 @@ guild: incidents: 816650601844572212 incidents_archive: 720671599790915702 python_news: &PYNEWS_WEBHOOK 704381182279942324 + filters: 926442964463521843 filter: diff --git a/tests/bot/exts/filtering/__init__.py b/tests/bot/exts/filtering/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bot/exts/filtering/test_filters.py b/tests/bot/exts/filtering/test_filters.py new file mode 100644 index 000000000..214637b52 --- /dev/null +++ b/tests/bot/exts/filtering/test_filters.py @@ -0,0 +1,41 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.token import TokenFilter +from tests.helpers import MockMember, MockMessage, MockTextChannel + + +class FilterTests(unittest.TestCase): + """Test functionality of the token filter.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + def test_token_filter_triggers(self): + """The filter should evaluate to True only if its token is found in the context content.""" + test_cases = ( + (r"hi", "oh hi there", True), + (r"hi", "goodbye", False), + (r"bla\d{2,4}", "bla18", True), + (r"bla\d{2,4}", "bla1", False) + ) + + for pattern, content, expected in test_cases: + with self.subTest( + pattern=pattern, + content=content, + expected=expected, + ): + filter_ = TokenFilter({ + "id": 1, + "content": pattern, + "description": None, + "settings": {}, + "additional_field": "{}" # noqa: P103 + }) + self.ctx.content = content + result = filter_.triggered_on(self.ctx) + self.assertEqual(result, expected) diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py new file mode 100644 index 000000000..ac21a5d47 --- /dev/null +++ b/tests/bot/exts/filtering/test_settings.py @@ -0,0 +1,20 @@ +import unittest + +import bot.exts.filtering._settings +from bot.exts.filtering._settings import create_settings + + +class FilterTests(unittest.TestCase): + """Test functionality of the Settings class and its subclasses.""" + + def test_create_settings_returns_none_for_empty_data(self): + """`create_settings` should return a tuple of two Nones when passed an empty dict.""" + result = create_settings({}) + + self.assertEquals(result, (None, None)) + + def test_unrecognized_entry_makes_a_warning(self): + """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`.""" + create_settings({"abcd": {}}) + + self.assertIn("abcd", bot.exts.filtering._settings._already_warned) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py new file mode 100644 index 000000000..4db6438ab --- /dev/null +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -0,0 +1,272 @@ +import unittest + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.bypass_roles import RoleBypass +from bot.exts.filtering._settings_types.channel_scope import ChannelScope +from bot.exts.filtering._settings_types.filter_dm import FilterDM +from bot.exts.filtering._settings_types.infraction_and_notification import ( + Infraction, InfractionAndNotification, superstar +) +from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel + + +class FilterTests(unittest.TestCase): + """Test functionality of the Settings class and its subclasses.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + def test_role_bypass_is_off_for_user_without_roles(self): + """The role bypass should trigger when a user has no roles.""" + member = MockMember() + self.ctx.author = member + bypass_entry = RoleBypass(["123"]) + + result = bypass_entry.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_role_bypass_is_on_for_a_user_with_the_right_role(self): + """The role bypass should not trigger when the user has one of its roles.""" + cases = ( + ([123], ["123"]), + ([123, 234], ["123"]), + ([123], ["123", "234"]), + ([123, 234], ["123", "234"]) + ) + + for user_role_ids, bypasses in cases: + with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses): + user_roles = [MockRole(id=role_id) for role_id in user_role_ids] + member = MockMember(roles=user_roles) + self.ctx.author = member + bypass_entry = RoleBypass(bypasses) + + result = bypass_entry.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_for_empty_channel_scope(self): + """A filter is enabled for all channels by default.""" + channel = MockTextChannel() + scope = ChannelScope({"disabled_channels": None, "disabled_categories": None, "enabled_channels": None}) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_context_doesnt_trigger_for_disabled_channel(self): + """A filter shouldn't trigger if it's been disabled in the channel.""" + channel = MockTextChannel(id=123) + scope = ChannelScope({"disabled_channels": [123], "disabled_categories": None, "enabled_channels": None}) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_in_disabled_category(self): + """A filter shouldn't trigger if it's been disabled in the category.""" + channel = MockTextChannel() + scope = ChannelScope({ + "disabled_channels": None, "disabled_categories": [channel.category.id], "enabled_channels": None + }) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_triggers_in_enabled_channel_in_disabled_category(self): + """A filter should trigger in an enabled channel even if it's been disabled in the category.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope({"disabled_channels": None, "disabled_categories": [234], "enabled_channels": [123]}) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_filtering_dms_when_necessary(self): + """A filter correctly ignores or triggers in a channel depending on the value of FilterDM.""" + cases = ( + (True, MockDMChannel(), True), + (False, MockDMChannel(), False), + (True, MockTextChannel(), True), + (False, MockTextChannel(), True) + ) + + for apply_in_dms, channel, expected in cases: + with self.subTest(apply_in_dms=apply_in_dms, channel=channel): + filter_dms = FilterDM(apply_in_dms) + self.ctx.channel = channel + + result = filter_dms.triggers_on(self.ctx) + + self.assertEqual(expected, result) + + def test_infraction_merge_of_same_infraction_type(self): + """When both infractions are of the same type, the one with the longer duration wins.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "hi", + "infraction_duration": 10, + "dm_content": "how", + "dm_embed": "what is" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "there", + "infraction_duration": 20, + "dm_content": "are you", + "dm_embed": "your name" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.MUTE, + "infraction_reason": "there", + "infraction_duration": 20.0, + "dm_content": "are you", + "dm_embed": "your name", + "_superstar": None + } + ) + + def test_infraction_merge_of_different_infraction_types(self): + """If there are two different infraction types, the one higher up the hierarchy should be picked.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "ban", + "infraction_reason": "", + "infraction_duration": 10, + "dm_content": "there", + "dm_embed": "" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.BAN, + "infraction_reason": "", + "infraction_duration": 10.0, + "dm_content": "there", + "dm_embed": "", + "_superstar": None + } + ) + + def test_infraction_merge_with_a_superstar(self): + """If there is a superstar infraction, it should be added to a separate field.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "mute", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "there", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "hello", + "infraction_duration": 10, + "dm_content": "you", + "dm_embed": "" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.MUTE, + "infraction_reason": "hi", + "infraction_duration": 20.0, + "dm_content": "there", + "dm_embed": "", + "_superstar": superstar("hello", 10.0) + } + ) + + def test_merge_two_superstar_infractions(self): + """When two superstar infractions are merged, the infraction type remains a superstar.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "", + "infraction_duration": 10, + "dm_content": "there", + "dm_embed": "" + }) + + result = infraction1 | infraction2 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.SUPERSTAR, + "infraction_reason": "hi", + "infraction_duration": 20.0, + "dm_content": "", + "dm_embed": "", + "_superstar": None + } + ) + + def test_merge_a_voiceban_and_a_superstar_with_another_superstar(self): + """An infraction with a superstar merged with a superstar should combine under `_superstar`.""" + infraction1 = InfractionAndNotification({ + "infraction_type": "voice ban", + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "hello", + "dm_embed": "" + }) + infraction2 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "bla", + "infraction_duration": 10, + "dm_content": "there", + "dm_embed": "" + }) + infraction3 = InfractionAndNotification({ + "infraction_type": "superstar", + "infraction_reason": "blabla", + "infraction_duration": 20, + "dm_content": "there", + "dm_embed": "" + }) + + result = infraction1 | infraction2 | infraction3 + + self.assertDictEqual( + result.to_dict(), + { + "infraction_type": Infraction.VOICE_BAN, + "infraction_reason": "hi", + "infraction_duration": 20, + "dm_content": "hello", + "dm_embed": "", + "_superstar": superstar("blabla", 20) + } + ) -- cgit v1.2.3 From d1ae7ce9235e4d63ee1dde282ca890ac5509f950 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Tue, 22 Feb 2022 22:15:19 +0200 Subject: Accept strings in channel scope and change role string interpretation The channel scope settings were changed to accomodate strings. That means that if a string is specified, the bot will look whether the context channel's name matches. If it's a number, it will match the ID. Accordingly the same changed was applied to the bypass roles and pings settings: if it's a non-numeric string, it will look for a role with that name. --- bot/exts/filtering/_settings_types/bypass_roles.py | 13 ++++++----- .../filtering/_settings_types/channel_scope.py | 27 ++++++++++++++++------ bot/exts/filtering/_settings_types/ping.py | 26 ++++++++++++--------- bot/exts/filtering/_utils.py | 8 ------- bot/exts/filtering/filtering.py | 2 +- tests/bot/exts/filtering/test_settings_entries.py | 8 +++---- tests/helpers.py | 2 +- 7 files changed, 48 insertions(+), 38 deletions(-) (limited to 'tests') diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py index 9665283ff..bfc4a30fd 100644 --- a/bot/exts/filtering/_settings_types/bypass_roles.py +++ b/bot/exts/filtering/_settings_types/bypass_roles.py @@ -4,7 +4,6 @@ from discord import Member from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry -from bot.exts.filtering._utils import ROLE_LITERALS class RoleBypass(ValidationEntry): @@ -16,14 +15,16 @@ class RoleBypass(ValidationEntry): super().__init__(entry_data) self.roles = set() for role in entry_data: - if role in ROLE_LITERALS: - self.roles.add(ROLE_LITERALS[role]) - elif role.isdigit(): + if role.isdigit(): self.roles.add(int(role)) - # Ignore entries that can't be resolved. + else: + self.roles.add(role) def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter should be triggered on this user given their roles.""" if not isinstance(ctx.author, Member): return True - return all(member_role.id not in self.roles for member_role in ctx.author.roles) + return all( + member_role.id not in self.roles and member_role.name not in self.roles + for member_role in ctx.author.roles + ) diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py index b17914f2f..63da6c7e5 100644 --- a/bot/exts/filtering/_settings_types/channel_scope.py +++ b/bot/exts/filtering/_settings_types/channel_scope.py @@ -1,9 +1,16 @@ -from typing import Any +from typing import Any, Union from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ValidationEntry +def maybe_cast_to_int(item: str) -> Union[str, int]: + """Cast the item to int if it consists of only digit, or leave as is otherwise.""" + if item.isdigit(): + return int(item) + return item + + class ChannelScope(ValidationEntry): """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" @@ -12,17 +19,17 @@ class ChannelScope(ValidationEntry): def __init__(self, entry_data: Any): super().__init__(entry_data) if entry_data["disabled_channels"]: - self.disabled_channels = set(entry_data["disabled_channels"]) + self.disabled_channels = set(map(maybe_cast_to_int, entry_data["disabled_channels"])) else: self.disabled_channels = set() if entry_data["disabled_categories"]: - self.disabled_categories = set(entry_data["disabled_categories"]) + self.disabled_categories = set(map(maybe_cast_to_int, entry_data["disabled_categories"])) else: self.disabled_categories = set() if entry_data["enabled_channels"]: - self.enabled_channels = set(entry_data["enabled_channels"]) + self.enabled_channels = set(map(maybe_cast_to_int, entry_data["enabled_channels"])) else: self.enabled_channels = set() @@ -34,12 +41,18 @@ class ChannelScope(ValidationEntry): If the channel is explicitly enabled, it bypasses the set disabled channels and categories. """ channel = ctx.channel - if hasattr(channel, "parent"): - channel = channel.parent - return ( + enabled_id = ( channel.id in self.enabled_channels or ( channel.id not in self.disabled_channels and (not channel.category or channel.category.id not in self.disabled_categories) ) ) + enabled_name = ( + channel.name in self.enabled_channels + or ( + channel.name not in self.disabled_channels + and (not channel.category or channel.category.name not in self.disabled_categories) + ) + ) + return enabled_id and enabled_name diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py index 857e4a7e8..0f9a014c4 100644 --- a/bot/exts/filtering/_settings_types/ping.py +++ b/bot/exts/filtering/_settings_types/ping.py @@ -5,7 +5,6 @@ from discord import Guild from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry -from bot.exts.filtering._utils import ROLE_LITERALS class Ping(ActionEntry): @@ -40,13 +39,18 @@ class Ping(ActionEntry): """Return the appropriate formatting for the formatting, be it a literal, a user ID, or a role ID.""" if mention in ("here", "everyone"): return f"@{mention}" - if mention in ROLE_LITERALS: - return f"<@&{ROLE_LITERALS[mention]}>" - if not mention.isdigit(): - return mention - - mention = int(mention) - if any(mention == role.id for role in guild.roles): - return f"<@&{mention}>" - else: - return f"<@{mention}>" + if mention.isdigit(): # It's an ID. + mention = int(mention) + if any(mention == role.id for role in guild.roles): + return f"<@&{mention}>" + else: + return f"<@{mention}>" + + # It's a name + for role in guild.roles: + if role.name == mention: + return role.mention + for member in guild.members: + if str(member) == mention: + return member.mention + return mention diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index 790f70ee5..d09262193 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -8,14 +8,6 @@ from typing import Set import regex -from bot.constants import Roles - -ROLE_LITERALS = { - "admins": Roles.admins, - "onduty": Roles.moderators, - "staff": Roles.helpers -} - VARIATION_SELECTORS = r"\uFE00-\uFE0F\U000E0100-\U000E01EF" INVISIBLE_RE = regex.compile(rf"[{VARIATION_SELECTORS}\p{{UNASSIGNED}}\p{{FORMAT}}\p{{CONTROL}}--\s]", regex.V1) ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIATION_SELECTORS}]]", regex.V1) diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 58e16043a..d34b4928a 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -89,7 +89,7 @@ class Filtering(Cog): @Cog.listener() async def on_message(self, msg: Message) -> None: """Filter the contents of a sent message.""" - if msg.author.bot: + if msg.author.bot or msg.webhook_id: return ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index 4db6438ab..d18861bd6 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -62,7 +62,7 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_for_disabled_channel(self): """A filter shouldn't trigger if it's been disabled in the channel.""" channel = MockTextChannel(id=123) - scope = ChannelScope({"disabled_channels": [123], "disabled_categories": None, "enabled_channels": None}) + scope = ChannelScope({"disabled_channels": ["123"], "disabled_categories": None, "enabled_channels": None}) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -71,9 +71,9 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_in_disabled_category(self): """A filter shouldn't trigger if it's been disabled in the category.""" - channel = MockTextChannel() + channel = MockTextChannel(category=MockCategoryChannel(id=456)) scope = ChannelScope({ - "disabled_channels": None, "disabled_categories": [channel.category.id], "enabled_channels": None + "disabled_channels": None, "disabled_categories": ["456"], "enabled_channels": None }) self.ctx.channel = channel @@ -84,7 +84,7 @@ class FilterTests(unittest.TestCase): def test_context_triggers_in_enabled_channel_in_disabled_category(self): """A filter should trigger in an enabled channel even if it's been disabled in the category.""" channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) - scope = ChannelScope({"disabled_channels": None, "disabled_categories": [234], "enabled_channels": [123]}) + scope = ChannelScope({"disabled_channels": None, "disabled_categories": ["234"], "enabled_channels": ["123"]}) self.ctx.channel = channel result = scope.triggers_on(self.ctx) diff --git a/tests/helpers.py b/tests/helpers.py index 17214553c..e74306d23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -423,7 +423,7 @@ category_channel_instance = discord.CategoryChannel( class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id)} - super().__init__(**collections.ChainMap(default_kwargs, kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) # Create a Message instance to get a realistic MagicMock of `discord.Message` -- cgit v1.2.3 From 72e164c38fed8d02fbe58412cf3a6de6e38aec09 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Fri, 30 Sep 2022 23:45:49 +0300 Subject: Split actions and validations to their own packcages This is a purely aesthetic choice. Additionally fixes a small bug where a missing entry type would repeatedly invoke a warning on cog load. --- bot/exts/filtering/_settings.py | 2 +- bot/exts/filtering/_settings_types/__init__.py | 9 +- .../filtering/_settings_types/actions/__init__.py | 8 ++ .../_settings_types/actions/delete_messages.py | 35 ++++++ .../actions/infraction_and_notification.py | 137 +++++++++++++++++++++ bot/exts/filtering/_settings_types/actions/ping.py | 70 +++++++++++ .../_settings_types/actions/send_alert.py | 24 ++++ bot/exts/filtering/_settings_types/bypass_roles.py | 33 ----- .../filtering/_settings_types/channel_scope.py | 66 ---------- .../filtering/_settings_types/delete_messages.py | 35 ------ bot/exts/filtering/_settings_types/enabled.py | 19 --- bot/exts/filtering/_settings_types/filter_dm.py | 17 --- .../_settings_types/infraction_and_notification.py | 137 --------------------- bot/exts/filtering/_settings_types/ping.py | 70 ----------- bot/exts/filtering/_settings_types/send_alert.py | 24 ---- .../_settings_types/validations/__init__.py | 8 ++ .../_settings_types/validations/bypass_roles.py | 33 +++++ .../_settings_types/validations/channel_scope.py | 66 ++++++++++ .../_settings_types/validations/enabled.py | 19 +++ .../_settings_types/validations/filter_dm.py | 17 +++ tests/bot/exts/filtering/test_settings_entries.py | 8 +- 21 files changed, 424 insertions(+), 413 deletions(-) create mode 100644 bot/exts/filtering/_settings_types/actions/__init__.py create mode 100644 bot/exts/filtering/_settings_types/actions/delete_messages.py create mode 100644 bot/exts/filtering/_settings_types/actions/infraction_and_notification.py create mode 100644 bot/exts/filtering/_settings_types/actions/ping.py create mode 100644 bot/exts/filtering/_settings_types/actions/send_alert.py delete mode 100644 bot/exts/filtering/_settings_types/bypass_roles.py delete mode 100644 bot/exts/filtering/_settings_types/channel_scope.py delete mode 100644 bot/exts/filtering/_settings_types/delete_messages.py delete mode 100644 bot/exts/filtering/_settings_types/enabled.py delete mode 100644 bot/exts/filtering/_settings_types/filter_dm.py delete mode 100644 bot/exts/filtering/_settings_types/infraction_and_notification.py delete mode 100644 bot/exts/filtering/_settings_types/ping.py delete mode 100644 bot/exts/filtering/_settings_types/send_alert.py create mode 100644 bot/exts/filtering/_settings_types/validations/__init__.py create mode 100644 bot/exts/filtering/_settings_types/validations/bypass_roles.py create mode 100644 bot/exts/filtering/_settings_types/validations/channel_scope.py create mode 100644 bot/exts/filtering/_settings_types/validations/enabled.py create mode 100644 bot/exts/filtering/_settings_types/validations/filter_dm.py (limited to 'tests') diff --git a/bot/exts/filtering/_settings.py b/bot/exts/filtering/_settings.py index f88b26ee3..cbd682d6d 100644 --- a/bot/exts/filtering/_settings.py +++ b/bot/exts/filtering/_settings.py @@ -31,7 +31,7 @@ def create_settings( action_data[entry_name] = entry_data elif entry_name in settings_types["ValidationEntry"]: validation_data[entry_name] = entry_data - else: + elif entry_name not in _already_warned: log.warning( f"A setting named {entry_name} was loaded from the database, but no matching class." ) diff --git a/bot/exts/filtering/_settings_types/__init__.py b/bot/exts/filtering/_settings_types/__init__.py index 620290cb2..61b5737d4 100644 --- a/bot/exts/filtering/_settings_types/__init__.py +++ b/bot/exts/filtering/_settings_types/__init__.py @@ -1,10 +1,5 @@ -from os.path import dirname - -from bot.exts.filtering._settings_types.settings_entry import ActionEntry, ValidationEntry -from bot.exts.filtering._utils import subclasses_in_package - -action_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ActionEntry) -validation_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ValidationEntry) +from bot.exts.filtering._settings_types.actions import action_types +from bot.exts.filtering._settings_types.validations import validation_types settings_types = { "ActionEntry": {settings_type.name: settings_type for settings_type in action_types}, diff --git a/bot/exts/filtering/_settings_types/actions/__init__.py b/bot/exts/filtering/_settings_types/actions/__init__.py new file mode 100644 index 000000000..a8175b976 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname + +from bot.exts.filtering._settings_types.settings_entry import ActionEntry +from bot.exts.filtering._utils import subclasses_in_package + +action_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ActionEntry) + +__all__ = [action_types] diff --git a/bot/exts/filtering/_settings_types/actions/delete_messages.py b/bot/exts/filtering/_settings_types/actions/delete_messages.py new file mode 100644 index 000000000..710cb0ed8 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/delete_messages.py @@ -0,0 +1,35 @@ +from contextlib import suppress +from typing import ClassVar + +from discord.errors import NotFound + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class DeleteMessages(ActionEntry): + """A setting entry which tells whether to delete the offending message(s).""" + + name: ClassVar[str] = "delete_messages" + description: ClassVar[str] = ( + "A boolean field. If True, the filter being triggered will cause the offending message to be deleted." + ) + + delete_messages: bool + + async def action(self, ctx: FilterContext) -> None: + """Delete the context message(s).""" + if not self.delete_messages or ctx.event not in (Event.MESSAGE, Event.MESSAGE_EDIT): + return + + with suppress(NotFound): + if ctx.message.guild: + await ctx.message.delete() + ctx.action_descriptions.append("deleted") + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, DeleteMessages): + return NotImplemented + + return DeleteMessages(delete_messages=self.delete_messages or other.delete_messages) diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py new file mode 100644 index 000000000..4fcf2aa65 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -0,0 +1,137 @@ +from datetime import timedelta +from enum import Enum, auto +from typing import ClassVar + +import arrow +from discord import Colour, Embed +from discord.errors import Forbidden +from pydantic import validator + +import bot +from bot.constants import Channels, Guild +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class Infraction(Enum): + """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" + + BAN = auto() + KICK = auto() + MUTE = auto() + VOICE_MUTE = auto() + SUPERSTAR = auto() + WARNING = auto() + WATCH = auto() + NOTE = auto() + + def __str__(self) -> str: + return self.name + + +class InfractionAndNotification(ActionEntry): + """ + A setting entry which specifies what infraction to issue and the notification to DM the user. + + Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. + """ + + name: ClassVar[str] = "infraction_and_notification" + description: ClassVar[dict[str, str]] = { + "infraction_type": ( + "The type of infraction to issue when the filter triggers, or 'NONE'. " + "If two infractions are triggered for the same message, " + "the harsher one will be applied (by type or duration). " + "Superstars will be triggered even if there is a harsher infraction.\n\n" + "Valid infraction types in order of harshness: " + ) + ", ".join(infraction.name for infraction in Infraction), + "infraction_duration": "How long the infraction should last for in seconds, or 'None' for permanent.", + "infraction_reason": "The reason delivered with the infraction.", + "dm_content": "The contents of a message to be DMed to the offending user.", + "dm_embed": "The contents of the embed to be DMed to the offending user." + } + + dm_content: str | None + dm_embed: str | None + infraction_type: Infraction | None + infraction_reason: str | None + infraction_duration: float | None + + @validator("infraction_type", pre=True) + @classmethod + def convert_infraction_name(cls, infr_type: str) -> Infraction: + """Convert the string to an Infraction by name.""" + return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else None + + async def action(self, ctx: FilterContext) -> None: + """Send the notification to the user, and apply any specified infractions.""" + # If there is no infraction to apply, any DM contents already provided in the context take precedence. + if self.infraction_type is None and (ctx.dm_content or ctx.dm_embed): + dm_content = ctx.dm_content + dm_embed = ctx.dm_embed + else: + dm_content = self.dm_content + dm_embed = self.dm_embed + + if dm_content or dm_embed: + formatting = {"domain": ctx.notification_domain} + dm_content = f"Hey {ctx.author.mention}!\n{dm_content.format(**formatting)}" + if dm_embed: + dm_embed = Embed(description=dm_embed.format(**formatting), colour=Colour.og_blurple()) + else: + dm_embed = None + + try: + await ctx.author.send(dm_content, embed=dm_embed) + ctx.action_descriptions.append("notified") + except Forbidden: + ctx.action_descriptions.append("notified (failed)") + + msg_ctx = await bot.instance.get_context(ctx.message) + msg_ctx.guild = bot.instance.get_guild(Guild.id) + msg_ctx.author = ctx.author + msg_ctx.channel = ctx.channel + + if self.infraction_type is not None: + if self.infraction_type == Infraction.BAN or not hasattr(ctx.channel, "guild"): + msg_ctx.channel = bot.instance.get_channel(Channels.mod_alerts) + msg_ctx.command = bot.instance.get_command(self.infraction_type.name.lower()) + await msg_ctx.invoke( + msg_ctx.command, + ctx.author, + arrow.utcnow() + timedelta(seconds=self.infraction_duration) + if self.infraction_duration is not None else None, + reason=self.infraction_reason + ) + ctx.action_descriptions.append(self.infraction_type.name.lower()) + + def __or__(self, other: ActionEntry): + """ + Combines two actions of the same type. Each type of action is executed once per filter. + + If the infractions are different, take the data of the one higher up the hierarchy. + + There is no clear way to properly combine several notification messages, especially when it's in two parts. + To avoid bombarding the user with several notifications, the message with the more significant infraction + is used. + """ + if not isinstance(other, InfractionAndNotification): + return NotImplemented + + # Lower number -> higher in the hierarchy + if self.infraction_type is None: + return other.copy() + elif other.infraction_type is None: + return self.copy() + elif self.infraction_type.value < other.infraction_type.value: + return self.copy() + elif self.infraction_type.value > other.infraction_type.value: + return other.copy() + else: + if self.infraction_duration is None or ( + other.infraction_duration is not None and self.infraction_duration > other.infraction_duration + ): + result = self.copy() + else: + result = other.copy() + return result diff --git a/bot/exts/filtering/_settings_types/actions/ping.py b/bot/exts/filtering/_settings_types/actions/ping.py new file mode 100644 index 000000000..0bfc12809 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/ping.py @@ -0,0 +1,70 @@ +from functools import cache +from typing import ClassVar + +from discord import Guild +from pydantic import validator + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class Ping(ActionEntry): + """A setting entry which adds the appropriate pings to the alert.""" + + name: ClassVar[str] = "mentions" + description: ClassVar[dict[str, str]] = { + "guild_pings": ( + "A list of role IDs/role names/user IDs/user names/here/everyone. " + "If a mod-alert is generated for a filter triggered in a public channel, these will be pinged." + ), + "dm_pings": ( + "A list of role IDs/role names/user IDs/user names/here/everyone. " + "If a mod-alert is generated for a filter triggered in DMs, these will be pinged." + ) + } + + guild_pings: set[str] + dm_pings: set[str] + + @validator("*") + @classmethod + def init_sequence_if_none(cls, pings: list[str]) -> list[str]: + """Initialize an empty sequence if the value is None.""" + if pings is None: + return [] + return pings + + async def action(self, ctx: FilterContext) -> None: + """Add the stored pings to the alert message content.""" + mentions = self.guild_pings if ctx.channel.guild else self.dm_pings + new_content = " ".join([self._resolve_mention(mention, ctx.channel.guild) for mention in mentions]) + ctx.alert_content = f"{new_content} {ctx.alert_content}" + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, Ping): + return NotImplemented + + return Ping(guild_pings=self.guild_pings | other.guild_pings, dm_pings=self.dm_pings | other.dm_pings) + + @staticmethod + @cache + def _resolve_mention(mention: str, guild: Guild) -> str: + """Return the appropriate formatting for the formatting, be it a literal, a user ID, or a role ID.""" + if mention in ("here", "everyone"): + return f"@{mention}" + if mention.isdigit(): # It's an ID. + mention = int(mention) + if any(mention == role.id for role in guild.roles): + return f"<@&{mention}>" + else: + return f"<@{mention}>" + + # It's a name + for role in guild.roles: + if role.name == mention: + return role.mention + for member in guild.members: + if str(member) == mention: + return member.mention + return mention diff --git a/bot/exts/filtering/_settings_types/actions/send_alert.py b/bot/exts/filtering/_settings_types/actions/send_alert.py new file mode 100644 index 000000000..04e400764 --- /dev/null +++ b/bot/exts/filtering/_settings_types/actions/send_alert.py @@ -0,0 +1,24 @@ +from typing import ClassVar + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ActionEntry + + +class SendAlert(ActionEntry): + """A setting entry which tells whether to send an alert message.""" + + name: ClassVar[str] = "send_alert" + description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created." + + send_alert: bool + + async def action(self, ctx: FilterContext) -> None: + """Add the stored pings to the alert message content.""" + ctx.send_alert = self.send_alert + + def __or__(self, other: ActionEntry): + """Combines two actions of the same type. Each type of action is executed once per filter.""" + if not isinstance(other, SendAlert): + return NotImplemented + + return SendAlert(send_alert=self.send_alert or other.send_alert) diff --git a/bot/exts/filtering/_settings_types/bypass_roles.py b/bot/exts/filtering/_settings_types/bypass_roles.py deleted file mode 100644 index a5c18cffc..000000000 --- a/bot/exts/filtering/_settings_types/bypass_roles.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import ClassVar, Union - -from discord import Member -from pydantic import validator - -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ValidationEntry - - -class RoleBypass(ValidationEntry): - """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" - - name: ClassVar[str] = "bypass_roles" - description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter." - - bypass_roles: set[Union[int, str]] - - @validator("bypass_roles", each_item=True) - @classmethod - def maybe_cast_to_int(cls, role: str) -> Union[int, str]: - """If the string is alphanumeric, cast it to int.""" - if role.isdigit(): - return int(role) - return role - - def triggers_on(self, ctx: FilterContext) -> bool: - """Return whether the filter should be triggered on this user given their roles.""" - if not isinstance(ctx.author, Member): - return True - return all( - member_role.id not in self.bypass_roles and member_role.name not in self.bypass_roles - for member_role in ctx.author.roles - ) diff --git a/bot/exts/filtering/_settings_types/channel_scope.py b/bot/exts/filtering/_settings_types/channel_scope.py deleted file mode 100644 index fd5206b81..000000000 --- a/bot/exts/filtering/_settings_types/channel_scope.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import ClassVar, Union - -from pydantic import validator - -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ValidationEntry - - -class ChannelScope(ValidationEntry): - """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" - - name: ClassVar[str] = "channel_scope" - description: ClassVar[str] = { - "disabled_channels": "A list of channel IDs or channel names. The filter will not trigger in these channels.", - "disabled_categories": ( - "A list of category IDs or category names. The filter will not trigger in these categories." - ), - "enabled_channels": ( - "A list of channel IDs or channel names. " - "The filter can trigger in these channels even if the category is disabled." - ) - } - - disabled_channels: set[Union[str, int]] - disabled_categories: set[Union[str, int]] - enabled_channels: set[Union[str, int]] - - @validator("*", pre=True) - @classmethod - def init_if_sequence_none(cls, sequence: list[str]) -> list[str]: - """Initialize an empty sequence if the value is None.""" - if sequence is None: - return [] - return sequence - - @validator("*", each_item=True) - @classmethod - def maybe_cast_items(cls, channel_or_category: str) -> Union[str, int]: - """Cast to int each value in each sequence if it is alphanumeric.""" - if channel_or_category.isdigit(): - return int(channel_or_category) - return channel_or_category - - def triggers_on(self, ctx: FilterContext) -> bool: - """ - Return whether the filter should be triggered in the given channel. - - The filter is invoked by default. - If the channel is explicitly enabled, it bypasses the set disabled channels and categories. - """ - channel = ctx.channel - enabled_id = ( - channel.id in self.enabled_channels - or ( - channel.id not in self.disabled_channels - and (not channel.category or channel.category.id not in self.disabled_categories) - ) - ) - enabled_name = ( - channel.name in self.enabled_channels - or ( - channel.name not in self.disabled_channels - and (not channel.category or channel.category.name not in self.disabled_categories) - ) - ) - return enabled_id and enabled_name diff --git a/bot/exts/filtering/_settings_types/delete_messages.py b/bot/exts/filtering/_settings_types/delete_messages.py deleted file mode 100644 index 710cb0ed8..000000000 --- a/bot/exts/filtering/_settings_types/delete_messages.py +++ /dev/null @@ -1,35 +0,0 @@ -from contextlib import suppress -from typing import ClassVar - -from discord.errors import NotFound - -from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._settings_types.settings_entry import ActionEntry - - -class DeleteMessages(ActionEntry): - """A setting entry which tells whether to delete the offending message(s).""" - - name: ClassVar[str] = "delete_messages" - description: ClassVar[str] = ( - "A boolean field. If True, the filter being triggered will cause the offending message to be deleted." - ) - - delete_messages: bool - - async def action(self, ctx: FilterContext) -> None: - """Delete the context message(s).""" - if not self.delete_messages or ctx.event not in (Event.MESSAGE, Event.MESSAGE_EDIT): - return - - with suppress(NotFound): - if ctx.message.guild: - await ctx.message.delete() - ctx.action_descriptions.append("deleted") - - def __or__(self, other: ActionEntry): - """Combines two actions of the same type. Each type of action is executed once per filter.""" - if not isinstance(other, DeleteMessages): - return NotImplemented - - return DeleteMessages(delete_messages=self.delete_messages or other.delete_messages) diff --git a/bot/exts/filtering/_settings_types/enabled.py b/bot/exts/filtering/_settings_types/enabled.py deleted file mode 100644 index 3b5e3e446..000000000 --- a/bot/exts/filtering/_settings_types/enabled.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import ClassVar - -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ValidationEntry - - -class Enabled(ValidationEntry): - """A setting entry which tells whether the filter is enabled.""" - - name: ClassVar[str] = "enabled" - description: ClassVar[str] = ( - "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." - ) - - enabled: bool - - def triggers_on(self, ctx: FilterContext) -> bool: - """Return whether the filter is enabled.""" - return self.enabled diff --git a/bot/exts/filtering/_settings_types/filter_dm.py b/bot/exts/filtering/_settings_types/filter_dm.py deleted file mode 100644 index 93022320f..000000000 --- a/bot/exts/filtering/_settings_types/filter_dm.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import ClassVar - -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ValidationEntry - - -class FilterDM(ValidationEntry): - """A setting entry which tells whether to apply the filter to DMs.""" - - name: ClassVar[str] = "filter_dm" - description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." - - filter_dm: bool - - def triggers_on(self, ctx: FilterContext) -> bool: - """Return whether the filter should be triggered even if it was triggered in DMs.""" - return hasattr(ctx.channel, "guild") or self.filter_dm diff --git a/bot/exts/filtering/_settings_types/infraction_and_notification.py b/bot/exts/filtering/_settings_types/infraction_and_notification.py deleted file mode 100644 index 4fcf2aa65..000000000 --- a/bot/exts/filtering/_settings_types/infraction_and_notification.py +++ /dev/null @@ -1,137 +0,0 @@ -from datetime import timedelta -from enum import Enum, auto -from typing import ClassVar - -import arrow -from discord import Colour, Embed -from discord.errors import Forbidden -from pydantic import validator - -import bot -from bot.constants import Channels, Guild -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ActionEntry - - -class Infraction(Enum): - """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" - - BAN = auto() - KICK = auto() - MUTE = auto() - VOICE_MUTE = auto() - SUPERSTAR = auto() - WARNING = auto() - WATCH = auto() - NOTE = auto() - - def __str__(self) -> str: - return self.name - - -class InfractionAndNotification(ActionEntry): - """ - A setting entry which specifies what infraction to issue and the notification to DM the user. - - Since a DM cannot be sent when a user is banned or kicked, these two functions need to be grouped together. - """ - - name: ClassVar[str] = "infraction_and_notification" - description: ClassVar[dict[str, str]] = { - "infraction_type": ( - "The type of infraction to issue when the filter triggers, or 'NONE'. " - "If two infractions are triggered for the same message, " - "the harsher one will be applied (by type or duration). " - "Superstars will be triggered even if there is a harsher infraction.\n\n" - "Valid infraction types in order of harshness: " - ) + ", ".join(infraction.name for infraction in Infraction), - "infraction_duration": "How long the infraction should last for in seconds, or 'None' for permanent.", - "infraction_reason": "The reason delivered with the infraction.", - "dm_content": "The contents of a message to be DMed to the offending user.", - "dm_embed": "The contents of the embed to be DMed to the offending user." - } - - dm_content: str | None - dm_embed: str | None - infraction_type: Infraction | None - infraction_reason: str | None - infraction_duration: float | None - - @validator("infraction_type", pre=True) - @classmethod - def convert_infraction_name(cls, infr_type: str) -> Infraction: - """Convert the string to an Infraction by name.""" - return Infraction[infr_type.replace(" ", "_").upper()] if infr_type else None - - async def action(self, ctx: FilterContext) -> None: - """Send the notification to the user, and apply any specified infractions.""" - # If there is no infraction to apply, any DM contents already provided in the context take precedence. - if self.infraction_type is None and (ctx.dm_content or ctx.dm_embed): - dm_content = ctx.dm_content - dm_embed = ctx.dm_embed - else: - dm_content = self.dm_content - dm_embed = self.dm_embed - - if dm_content or dm_embed: - formatting = {"domain": ctx.notification_domain} - dm_content = f"Hey {ctx.author.mention}!\n{dm_content.format(**formatting)}" - if dm_embed: - dm_embed = Embed(description=dm_embed.format(**formatting), colour=Colour.og_blurple()) - else: - dm_embed = None - - try: - await ctx.author.send(dm_content, embed=dm_embed) - ctx.action_descriptions.append("notified") - except Forbidden: - ctx.action_descriptions.append("notified (failed)") - - msg_ctx = await bot.instance.get_context(ctx.message) - msg_ctx.guild = bot.instance.get_guild(Guild.id) - msg_ctx.author = ctx.author - msg_ctx.channel = ctx.channel - - if self.infraction_type is not None: - if self.infraction_type == Infraction.BAN or not hasattr(ctx.channel, "guild"): - msg_ctx.channel = bot.instance.get_channel(Channels.mod_alerts) - msg_ctx.command = bot.instance.get_command(self.infraction_type.name.lower()) - await msg_ctx.invoke( - msg_ctx.command, - ctx.author, - arrow.utcnow() + timedelta(seconds=self.infraction_duration) - if self.infraction_duration is not None else None, - reason=self.infraction_reason - ) - ctx.action_descriptions.append(self.infraction_type.name.lower()) - - def __or__(self, other: ActionEntry): - """ - Combines two actions of the same type. Each type of action is executed once per filter. - - If the infractions are different, take the data of the one higher up the hierarchy. - - There is no clear way to properly combine several notification messages, especially when it's in two parts. - To avoid bombarding the user with several notifications, the message with the more significant infraction - is used. - """ - if not isinstance(other, InfractionAndNotification): - return NotImplemented - - # Lower number -> higher in the hierarchy - if self.infraction_type is None: - return other.copy() - elif other.infraction_type is None: - return self.copy() - elif self.infraction_type.value < other.infraction_type.value: - return self.copy() - elif self.infraction_type.value > other.infraction_type.value: - return other.copy() - else: - if self.infraction_duration is None or ( - other.infraction_duration is not None and self.infraction_duration > other.infraction_duration - ): - result = self.copy() - else: - result = other.copy() - return result diff --git a/bot/exts/filtering/_settings_types/ping.py b/bot/exts/filtering/_settings_types/ping.py deleted file mode 100644 index 0bfc12809..000000000 --- a/bot/exts/filtering/_settings_types/ping.py +++ /dev/null @@ -1,70 +0,0 @@ -from functools import cache -from typing import ClassVar - -from discord import Guild -from pydantic import validator - -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ActionEntry - - -class Ping(ActionEntry): - """A setting entry which adds the appropriate pings to the alert.""" - - name: ClassVar[str] = "mentions" - description: ClassVar[dict[str, str]] = { - "guild_pings": ( - "A list of role IDs/role names/user IDs/user names/here/everyone. " - "If a mod-alert is generated for a filter triggered in a public channel, these will be pinged." - ), - "dm_pings": ( - "A list of role IDs/role names/user IDs/user names/here/everyone. " - "If a mod-alert is generated for a filter triggered in DMs, these will be pinged." - ) - } - - guild_pings: set[str] - dm_pings: set[str] - - @validator("*") - @classmethod - def init_sequence_if_none(cls, pings: list[str]) -> list[str]: - """Initialize an empty sequence if the value is None.""" - if pings is None: - return [] - return pings - - async def action(self, ctx: FilterContext) -> None: - """Add the stored pings to the alert message content.""" - mentions = self.guild_pings if ctx.channel.guild else self.dm_pings - new_content = " ".join([self._resolve_mention(mention, ctx.channel.guild) for mention in mentions]) - ctx.alert_content = f"{new_content} {ctx.alert_content}" - - def __or__(self, other: ActionEntry): - """Combines two actions of the same type. Each type of action is executed once per filter.""" - if not isinstance(other, Ping): - return NotImplemented - - return Ping(guild_pings=self.guild_pings | other.guild_pings, dm_pings=self.dm_pings | other.dm_pings) - - @staticmethod - @cache - def _resolve_mention(mention: str, guild: Guild) -> str: - """Return the appropriate formatting for the formatting, be it a literal, a user ID, or a role ID.""" - if mention in ("here", "everyone"): - return f"@{mention}" - if mention.isdigit(): # It's an ID. - mention = int(mention) - if any(mention == role.id for role in guild.roles): - return f"<@&{mention}>" - else: - return f"<@{mention}>" - - # It's a name - for role in guild.roles: - if role.name == mention: - return role.mention - for member in guild.members: - if str(member) == mention: - return member.mention - return mention diff --git a/bot/exts/filtering/_settings_types/send_alert.py b/bot/exts/filtering/_settings_types/send_alert.py deleted file mode 100644 index 04e400764..000000000 --- a/bot/exts/filtering/_settings_types/send_alert.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import ClassVar - -from bot.exts.filtering._filter_context import FilterContext -from bot.exts.filtering._settings_types.settings_entry import ActionEntry - - -class SendAlert(ActionEntry): - """A setting entry which tells whether to send an alert message.""" - - name: ClassVar[str] = "send_alert" - description: ClassVar[str] = "A boolean. If all filters triggered set this to False, no mod-alert will be created." - - send_alert: bool - - async def action(self, ctx: FilterContext) -> None: - """Add the stored pings to the alert message content.""" - ctx.send_alert = self.send_alert - - def __or__(self, other: ActionEntry): - """Combines two actions of the same type. Each type of action is executed once per filter.""" - if not isinstance(other, SendAlert): - return NotImplemented - - return SendAlert(send_alert=self.send_alert or other.send_alert) diff --git a/bot/exts/filtering/_settings_types/validations/__init__.py b/bot/exts/filtering/_settings_types/validations/__init__.py new file mode 100644 index 000000000..5c44e8b27 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname + +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry +from bot.exts.filtering._utils import subclasses_in_package + +validation_types = subclasses_in_package(dirname(__file__), f"{__name__}.", ValidationEntry) + +__all__ = [validation_types] diff --git a/bot/exts/filtering/_settings_types/validations/bypass_roles.py b/bot/exts/filtering/_settings_types/validations/bypass_roles.py new file mode 100644 index 000000000..a5c18cffc --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/bypass_roles.py @@ -0,0 +1,33 @@ +from typing import ClassVar, Union + +from discord import Member +from pydantic import validator + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class RoleBypass(ValidationEntry): + """A setting entry which tells whether the roles the member has allow them to bypass the filter.""" + + name: ClassVar[str] = "bypass_roles" + description: ClassVar[str] = "A list of role IDs or role names. Users with these roles will not trigger the filter." + + bypass_roles: set[Union[int, str]] + + @validator("bypass_roles", each_item=True) + @classmethod + def maybe_cast_to_int(cls, role: str) -> Union[int, str]: + """If the string is alphanumeric, cast it to int.""" + if role.isdigit(): + return int(role) + return role + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered on this user given their roles.""" + if not isinstance(ctx.author, Member): + return True + return all( + member_role.id not in self.bypass_roles and member_role.name not in self.bypass_roles + for member_role in ctx.author.roles + ) diff --git a/bot/exts/filtering/_settings_types/validations/channel_scope.py b/bot/exts/filtering/_settings_types/validations/channel_scope.py new file mode 100644 index 000000000..fd5206b81 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/channel_scope.py @@ -0,0 +1,66 @@ +from typing import ClassVar, Union + +from pydantic import validator + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class ChannelScope(ValidationEntry): + """A setting entry which tells whether the filter was invoked in a whitelisted channel or category.""" + + name: ClassVar[str] = "channel_scope" + description: ClassVar[str] = { + "disabled_channels": "A list of channel IDs or channel names. The filter will not trigger in these channels.", + "disabled_categories": ( + "A list of category IDs or category names. The filter will not trigger in these categories." + ), + "enabled_channels": ( + "A list of channel IDs or channel names. " + "The filter can trigger in these channels even if the category is disabled." + ) + } + + disabled_channels: set[Union[str, int]] + disabled_categories: set[Union[str, int]] + enabled_channels: set[Union[str, int]] + + @validator("*", pre=True) + @classmethod + def init_if_sequence_none(cls, sequence: list[str]) -> list[str]: + """Initialize an empty sequence if the value is None.""" + if sequence is None: + return [] + return sequence + + @validator("*", each_item=True) + @classmethod + def maybe_cast_items(cls, channel_or_category: str) -> Union[str, int]: + """Cast to int each value in each sequence if it is alphanumeric.""" + if channel_or_category.isdigit(): + return int(channel_or_category) + return channel_or_category + + def triggers_on(self, ctx: FilterContext) -> bool: + """ + Return whether the filter should be triggered in the given channel. + + The filter is invoked by default. + If the channel is explicitly enabled, it bypasses the set disabled channels and categories. + """ + channel = ctx.channel + enabled_id = ( + channel.id in self.enabled_channels + or ( + channel.id not in self.disabled_channels + and (not channel.category or channel.category.id not in self.disabled_categories) + ) + ) + enabled_name = ( + channel.name in self.enabled_channels + or ( + channel.name not in self.disabled_channels + and (not channel.category or channel.category.name not in self.disabled_categories) + ) + ) + return enabled_id and enabled_name diff --git a/bot/exts/filtering/_settings_types/validations/enabled.py b/bot/exts/filtering/_settings_types/validations/enabled.py new file mode 100644 index 000000000..3b5e3e446 --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/enabled.py @@ -0,0 +1,19 @@ +from typing import ClassVar + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class Enabled(ValidationEntry): + """A setting entry which tells whether the filter is enabled.""" + + name: ClassVar[str] = "enabled" + description: ClassVar[str] = ( + "A boolean field. Setting it to False allows disabling the filter without deleting it entirely." + ) + + enabled: bool + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter is enabled.""" + return self.enabled diff --git a/bot/exts/filtering/_settings_types/validations/filter_dm.py b/bot/exts/filtering/_settings_types/validations/filter_dm.py new file mode 100644 index 000000000..93022320f --- /dev/null +++ b/bot/exts/filtering/_settings_types/validations/filter_dm.py @@ -0,0 +1,17 @@ +from typing import ClassVar + +from bot.exts.filtering._filter_context import FilterContext +from bot.exts.filtering._settings_types.settings_entry import ValidationEntry + + +class FilterDM(ValidationEntry): + """A setting entry which tells whether to apply the filter to DMs.""" + + name: ClassVar[str] = "filter_dm" + description: ClassVar[str] = "A boolean field. If True, the filter can trigger for messages sent to the bot in DMs." + + filter_dm: bool + + def triggers_on(self, ctx: FilterContext) -> bool: + """Return whether the filter should be triggered even if it was triggered in DMs.""" + return hasattr(ctx.channel, "guild") or self.filter_dm diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index d18861bd6..8dba5cb26 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -1,12 +1,12 @@ import unittest from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._settings_types.bypass_roles import RoleBypass -from bot.exts.filtering._settings_types.channel_scope import ChannelScope -from bot.exts.filtering._settings_types.filter_dm import FilterDM -from bot.exts.filtering._settings_types.infraction_and_notification import ( +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( Infraction, InfractionAndNotification, superstar ) +from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass +from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope +from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel -- cgit v1.2.3 From 7d46b1ed1fdec2a052c147a89f69354469bbfd18 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Sat, 1 Oct 2022 00:14:21 +0300 Subject: Fix tests --- .../_settings_types/validations/bypass_roles.py | 2 +- .../_settings_types/validations/filter_dm.py | 2 +- tests/bot/exts/filtering/test_settings.py | 2 +- tests/bot/exts/filtering/test_settings_entries.py | 182 +++++---------------- tests/bot/rules/test_mentions.py | 131 --------------- tests/helpers.py | 6 +- 6 files changed, 44 insertions(+), 281 deletions(-) delete mode 100644 tests/bot/rules/test_mentions.py (limited to 'tests') diff --git a/bot/exts/filtering/_settings_types/validations/bypass_roles.py b/bot/exts/filtering/_settings_types/validations/bypass_roles.py index a5c18cffc..c1e6f885d 100644 --- a/bot/exts/filtering/_settings_types/validations/bypass_roles.py +++ b/bot/exts/filtering/_settings_types/validations/bypass_roles.py @@ -15,7 +15,7 @@ class RoleBypass(ValidationEntry): bypass_roles: set[Union[int, str]] - @validator("bypass_roles", each_item=True) + @validator("bypass_roles", pre=True, each_item=True) @classmethod def maybe_cast_to_int(cls, role: str) -> Union[int, str]: """If the string is alphanumeric, cast it to int.""" diff --git a/bot/exts/filtering/_settings_types/validations/filter_dm.py b/bot/exts/filtering/_settings_types/validations/filter_dm.py index 93022320f..b9e566253 100644 --- a/bot/exts/filtering/_settings_types/validations/filter_dm.py +++ b/bot/exts/filtering/_settings_types/validations/filter_dm.py @@ -14,4 +14,4 @@ class FilterDM(ValidationEntry): def triggers_on(self, ctx: FilterContext) -> bool: """Return whether the filter should be triggered even if it was triggered in DMs.""" - return hasattr(ctx.channel, "guild") or self.filter_dm + return ctx.channel.guild is not None or self.filter_dm diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py index ac21a5d47..5a289c1cf 100644 --- a/tests/bot/exts/filtering/test_settings.py +++ b/tests/bot/exts/filtering/test_settings.py @@ -11,7 +11,7 @@ class FilterTests(unittest.TestCase): """`create_settings` should return a tuple of two Nones when passed an empty dict.""" result = create_settings({}) - self.assertEquals(result, (None, None)) + self.assertEqual(result, (None, None)) def test_unrecognized_entry_makes_a_warning(self): """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`.""" diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index 8dba5cb26..34b155d6b 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -1,9 +1,7 @@ import unittest from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( - Infraction, InfractionAndNotification, superstar -) +from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM @@ -23,7 +21,7 @@ class FilterTests(unittest.TestCase): """The role bypass should trigger when a user has no roles.""" member = MockMember() self.ctx.author = member - bypass_entry = RoleBypass(["123"]) + bypass_entry = RoleBypass(bypass_roles=["123"]) result = bypass_entry.triggers_on(self.ctx) @@ -43,7 +41,7 @@ class FilterTests(unittest.TestCase): user_roles = [MockRole(id=role_id) for role_id in user_role_ids] member = MockMember(roles=user_roles) self.ctx.author = member - bypass_entry = RoleBypass(bypasses) + bypass_entry = RoleBypass(bypass_roles=bypasses) result = bypass_entry.triggers_on(self.ctx) @@ -52,7 +50,7 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_for_empty_channel_scope(self): """A filter is enabled for all channels by default.""" channel = MockTextChannel() - scope = ChannelScope({"disabled_channels": None, "disabled_categories": None, "enabled_channels": None}) + scope = ChannelScope(disabled_channels=None, disabled_categories=None, enabled_channels=None) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -62,7 +60,7 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_for_disabled_channel(self): """A filter shouldn't trigger if it's been disabled in the channel.""" channel = MockTextChannel(id=123) - scope = ChannelScope({"disabled_channels": ["123"], "disabled_categories": None, "enabled_channels": None}) + scope = ChannelScope(disabled_channels=["123"], disabled_categories=None, enabled_channels=None) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -72,9 +70,7 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_in_disabled_category(self): """A filter shouldn't trigger if it's been disabled in the category.""" channel = MockTextChannel(category=MockCategoryChannel(id=456)) - scope = ChannelScope({ - "disabled_channels": None, "disabled_categories": ["456"], "enabled_channels": None - }) + scope = ChannelScope(disabled_channels=None, disabled_categories=["456"], enabled_channels=None) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -84,7 +80,7 @@ class FilterTests(unittest.TestCase): def test_context_triggers_in_enabled_channel_in_disabled_category(self): """A filter should trigger in an enabled channel even if it's been disabled in the category.""" channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) - scope = ChannelScope({"disabled_channels": None, "disabled_categories": ["234"], "enabled_channels": ["123"]}) + scope = ChannelScope(disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"]) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -102,7 +98,7 @@ class FilterTests(unittest.TestCase): for apply_in_dms, channel, expected in cases: with self.subTest(apply_in_dms=apply_in_dms, channel=channel): - filter_dms = FilterDM(apply_in_dms) + filter_dms = FilterDM(filter_dm=apply_in_dms) self.ctx.channel = channel result = filter_dms.triggers_on(self.ctx) @@ -111,162 +107,60 @@ class FilterTests(unittest.TestCase): def test_infraction_merge_of_same_infraction_type(self): """When both infractions are of the same type, the one with the longer duration wins.""" - infraction1 = InfractionAndNotification({ - "infraction_type": "mute", - "infraction_reason": "hi", - "infraction_duration": 10, - "dm_content": "how", - "dm_embed": "what is" - }) - infraction2 = InfractionAndNotification({ - "infraction_type": "mute", - "infraction_reason": "there", - "infraction_duration": 20, - "dm_content": "are you", - "dm_embed": "your name" - }) + infraction1 = InfractionAndNotification( + infraction_type="MUTE", + infraction_reason="hi", + infraction_duration=10, + dm_content="how", + dm_embed="what is" + ) + infraction2 = InfractionAndNotification( + infraction_type="MUTE", + infraction_reason="there", + infraction_duration=20, + dm_content="are you", + dm_embed="your name" + ) result = infraction1 | infraction2 self.assertDictEqual( - result.to_dict(), + result.dict(), { "infraction_type": Infraction.MUTE, "infraction_reason": "there", "infraction_duration": 20.0, "dm_content": "are you", "dm_embed": "your name", - "_superstar": None } ) def test_infraction_merge_of_different_infraction_types(self): """If there are two different infraction types, the one higher up the hierarchy should be picked.""" - infraction1 = InfractionAndNotification({ - "infraction_type": "mute", - "infraction_reason": "hi", - "infraction_duration": 20, - "dm_content": "", - "dm_embed": "" - }) - infraction2 = InfractionAndNotification({ - "infraction_type": "ban", - "infraction_reason": "", - "infraction_duration": 10, - "dm_content": "there", - "dm_embed": "" - }) + infraction1 = InfractionAndNotification( + infraction_type="MUTE", + infraction_reason="hi", + infraction_duration=20, + dm_content="", + dm_embed="" + ) + infraction2 = InfractionAndNotification( + infraction_type="BAN", + infraction_reason="", + infraction_duration=10, + dm_content="there", + dm_embed="" + ) result = infraction1 | infraction2 self.assertDictEqual( - result.to_dict(), + result.dict(), { "infraction_type": Infraction.BAN, "infraction_reason": "", "infraction_duration": 10.0, "dm_content": "there", "dm_embed": "", - "_superstar": None - } - ) - - def test_infraction_merge_with_a_superstar(self): - """If there is a superstar infraction, it should be added to a separate field.""" - infraction1 = InfractionAndNotification({ - "infraction_type": "mute", - "infraction_reason": "hi", - "infraction_duration": 20, - "dm_content": "there", - "dm_embed": "" - }) - infraction2 = InfractionAndNotification({ - "infraction_type": "superstar", - "infraction_reason": "hello", - "infraction_duration": 10, - "dm_content": "you", - "dm_embed": "" - }) - - result = infraction1 | infraction2 - - self.assertDictEqual( - result.to_dict(), - { - "infraction_type": Infraction.MUTE, - "infraction_reason": "hi", - "infraction_duration": 20.0, - "dm_content": "there", - "dm_embed": "", - "_superstar": superstar("hello", 10.0) - } - ) - - def test_merge_two_superstar_infractions(self): - """When two superstar infractions are merged, the infraction type remains a superstar.""" - infraction1 = InfractionAndNotification({ - "infraction_type": "superstar", - "infraction_reason": "hi", - "infraction_duration": 20, - "dm_content": "", - "dm_embed": "" - }) - infraction2 = InfractionAndNotification({ - "infraction_type": "superstar", - "infraction_reason": "", - "infraction_duration": 10, - "dm_content": "there", - "dm_embed": "" - }) - - result = infraction1 | infraction2 - - self.assertDictEqual( - result.to_dict(), - { - "infraction_type": Infraction.SUPERSTAR, - "infraction_reason": "hi", - "infraction_duration": 20.0, - "dm_content": "", - "dm_embed": "", - "_superstar": None - } - ) - - def test_merge_a_voiceban_and_a_superstar_with_another_superstar(self): - """An infraction with a superstar merged with a superstar should combine under `_superstar`.""" - infraction1 = InfractionAndNotification({ - "infraction_type": "voice ban", - "infraction_reason": "hi", - "infraction_duration": 20, - "dm_content": "hello", - "dm_embed": "" - }) - infraction2 = InfractionAndNotification({ - "infraction_type": "superstar", - "infraction_reason": "bla", - "infraction_duration": 10, - "dm_content": "there", - "dm_embed": "" - }) - infraction3 = InfractionAndNotification({ - "infraction_type": "superstar", - "infraction_reason": "blabla", - "infraction_duration": 20, - "dm_content": "there", - "dm_embed": "" - }) - - result = infraction1 | infraction2 | infraction3 - - self.assertDictEqual( - result.to_dict(), - { - "infraction_type": Infraction.VOICE_BAN, - "infraction_reason": "hi", - "infraction_duration": 20, - "dm_content": "hello", - "dm_embed": "", - "_superstar": superstar("blabla", 20) } ) diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py deleted file mode 100644 index e1f904917..000000000 --- a/tests/bot/rules/test_mentions.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Iterable, Optional - -import discord - -from bot.rules import mentions -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMember, MockMessage, MockMessageReference - - -def make_msg( - author: str, - total_user_mentions: int, - total_bot_mentions: int = 0, - *, - reference: Optional[MockMessageReference] = None -) -> MockMessage: - """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions.""" - user_mentions = [MockMember() for _ in range(total_user_mentions)] - bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] - - mentions = user_mentions + bot_mentions - if reference is not None: - # For the sake of these tests we assume that all references are mentions. - mentions.append(reference.resolved.author) - msg_type = discord.MessageType.reply - else: - msg_type = discord.MessageType.default - - return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type) - - -class TestMentions(RuleTest): - """Tests applying the `mentions` antispam rule.""" - - def setUp(self): - self.apply = mentions.apply - self.config = { - "max": 2, - "interval": 10, - } - - async def test_mentions_within_limit(self): - """Messages with an allowed amount of mentions.""" - cases = ( - [make_msg("bob", 0)], - [make_msg("bob", 2)], - [make_msg("bob", 1), make_msg("bob", 1)], - [make_msg("bob", 1), make_msg("alice", 2)], - ) - - await self.run_allowed(cases) - - async def test_mentions_exceeding_limit(self): - """Messages with a higher than allowed amount of mentions.""" - cases = ( - DisallowedCase( - [make_msg("bob", 3)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], - ("alice",), - 3, - ), - DisallowedCase( - [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], - ("bob",), - 4, - ), - DisallowedCase( - [make_msg("bob", 3, 1)], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("bob", 3, reference=MockMessageReference())], - ("bob",), - 3, - ), - DisallowedCase( - [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))], - ("bob",), - 3 - ) - ) - - await self.run_disallowed(cases) - - async def test_ignore_bot_mentions(self): - """Messages with an allowed amount of mentions, also containing bot mentions.""" - cases = ( - [make_msg("bob", 0, 3)], - [make_msg("bob", 2, 1)], - [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], - [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] - ) - - await self.run_allowed(cases) - - async def test_ignore_reply_mentions(self): - """Messages with an allowed amount of mentions in the content, also containing reply mentions.""" - cases = ( - [ - make_msg("bob", 2, reference=MockMessageReference()) - ], - [ - make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)) - ], - [ - make_msg("bob", 2, reference=MockMessageReference()), - make_msg("bob", 0, reference=MockMessageReference()) - ], - [ - make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)), - make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True)) - ] - ) - - await self.run_allowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/helpers.py b/tests/helpers.py index 28a8e40a7..35a8a71f7 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): """ - A MagicMock subclass to mock TextChannel objects. + A MagicMock subclass to mock DMChannel objects. - Instances of this class will follow the specifications of `discord.TextChannel` instances. For + Instances of this class will follow the specifications of `discord.DMChannel` instances. For more information, see the `MockGuild` docstring. """ spec_set = dm_channel_instance def __init__(self, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} + default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None} super().__init__(**collections.ChainMap(kwargs, default_kwargs)) -- cgit v1.2.3 From c20398233a4a792e3207d52765aaf530a468351a Mon Sep 17 00:00:00 2001 From: mbaruh Date: Tue, 1 Nov 2022 20:57:09 +0200 Subject: Add the rest of the antispam rules This is mostly a copy-paste of the implementations in the old system into the new system's structure. The mentions rule required changing the `triggers_on` method to async. --- bot/exts/filtering/_filter_lists/antispam.py | 10 ++- bot/exts/filtering/_filter_lists/domain.py | 2 +- bot/exts/filtering/_filter_lists/extension.py | 4 +- bot/exts/filtering/_filter_lists/filter_list.py | 16 ++-- bot/exts/filtering/_filter_lists/invite.py | 5 +- bot/exts/filtering/_filter_lists/token.py | 2 +- bot/exts/filtering/_filter_lists/unique.py | 2 +- .../filtering/_filters/antispam/attachments.py | 43 +++++++++++ bot/exts/filtering/_filters/antispam/burst.py | 41 ++++++++++ bot/exts/filtering/_filters/antispam/chars.py | 43 +++++++++++ bot/exts/filtering/_filters/antispam/duplicates.py | 4 +- bot/exts/filtering/_filters/antispam/emoji.py | 53 +++++++++++++ bot/exts/filtering/_filters/antispam/links.py | 52 +++++++++++++ bot/exts/filtering/_filters/antispam/mentions.py | 90 ++++++++++++++++++++++ bot/exts/filtering/_filters/antispam/newlines.py | 61 +++++++++++++++ .../filtering/_filters/antispam/role_mentions.py | 42 ++++++++++ bot/exts/filtering/_filters/domain.py | 2 +- bot/exts/filtering/_filters/extension.py | 2 +- bot/exts/filtering/_filters/filter.py | 6 +- bot/exts/filtering/_filters/invite.py | 2 +- bot/exts/filtering/_filters/token.py | 2 +- .../filtering/_filters/unique/discord_token.py | 2 +- bot/exts/filtering/_filters/unique/everyone.py | 2 +- bot/exts/filtering/_filters/unique/rich_embed.py | 2 +- bot/exts/filtering/_filters/unique/webhook.py | 2 +- bot/exts/filtering/filtering.py | 5 +- tests/bot/exts/filtering/test_filters.py | 6 +- 27 files changed, 468 insertions(+), 35 deletions(-) create mode 100644 bot/exts/filtering/_filters/antispam/attachments.py create mode 100644 bot/exts/filtering/_filters/antispam/burst.py create mode 100644 bot/exts/filtering/_filters/antispam/chars.py create mode 100644 bot/exts/filtering/_filters/antispam/emoji.py create mode 100644 bot/exts/filtering/_filters/antispam/links.py create mode 100644 bot/exts/filtering/_filters/antispam/mentions.py create mode 100644 bot/exts/filtering/_filters/antispam/newlines.py create mode 100644 bot/exts/filtering/_filters/antispam/role_mentions.py (limited to 'tests') diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py index 2dab54ce6..b2f873094 100644 --- a/bot/exts/filtering/_filter_lists/antispam.py +++ b/bot/exts/filtering/_filter_lists/antispam.py @@ -34,7 +34,9 @@ class AntispamList(UniquesListBase): """ A list of anti-spam rules. - Messages from the last X seconds is passed to each rule, which decide whether it triggers across those messages. + Messages from the last X seconds are passed to each rule, which decides whether it triggers across those messages. + + The infraction reason is set dynamically. """ name = "antispam" @@ -67,7 +69,7 @@ class AntispamList(UniquesListBase): takewhile(lambda msg: msg.created_at > earliest_relevant_at, self.filtering_cog.message_cache) ) new_ctx = ctx.replace(content=relevant_messages) - triggers = sublist.filter_list_result(new_ctx) + triggers = await sublist.filter_list_result(new_ctx) if not triggers: return None, [] @@ -88,7 +90,9 @@ class AntispamList(UniquesListBase): # Smaller infraction value = higher in hierarchy. if not current_infraction or new_infraction.infraction_type.value < current_infraction.value: # Pick the first triggered filter for the reason, there's no good way to decide between them. - new_infraction.infraction_reason = f"{triggers[0].name} spam - {ctx.filter_info[triggers[0]]}" + new_infraction.infraction_reason = ( + f"{triggers[0].name.replace('_', ' ')} spam – {ctx.filter_info[triggers[0]]}" + ) current_actions["infraction_and_notification"] = new_infraction self.message_deletion_queue[ctx.author].current_infraction = new_infraction.infraction_type else: diff --git a/bot/exts/filtering/_filter_lists/domain.py b/bot/exts/filtering/_filter_lists/domain.py index d97aa252b..0b56e8d73 100644 --- a/bot/exts/filtering/_filter_lists/domain.py +++ b/bot/exts/filtering/_filter_lists/domain.py @@ -52,7 +52,7 @@ class DomainsList(FilterList[DomainFilter]): urls = {match.group(1).lower().rstrip("/") for match in URL_RE.finditer(text)} new_ctx = ctx.replace(content=urls) - triggers = self[ListType.DENY].filter_list_result(new_ctx) + triggers = await self[ListType.DENY].filter_list_result(new_ctx) ctx.notification_domain = new_ctx.notification_domain actions = None messages = [] diff --git a/bot/exts/filtering/_filter_lists/extension.py b/bot/exts/filtering/_filter_lists/extension.py index 3f9d2b287..a53520bf7 100644 --- a/bot/exts/filtering/_filter_lists/extension.py +++ b/bot/exts/filtering/_filter_lists/extension.py @@ -76,7 +76,9 @@ class ExtensionsList(FilterList[ExtensionFilter]): (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.message.attachments } new_ctx = ctx.replace(content={ext for ext, _ in all_ext}) # And prepare the context for the filters to read. - triggered = [filter_ for filter_ in self[ListType.ALLOW].filters.values() if filter_.triggered_on(new_ctx)] + triggered = [ + filter_ for filter_ in self[ListType.ALLOW].filters.values() if await filter_.triggered_on(new_ctx) + ] allowed_ext = {filter_.content for filter_ in triggered} # Get the extensions in the message that are allowed. # See if there are any extensions left which aren't allowed. diff --git a/bot/exts/filtering/_filter_lists/filter_list.py b/bot/exts/filtering/_filter_lists/filter_list.py index f9db54a21..938766aca 100644 --- a/bot/exts/filtering/_filter_lists/filter_list.py +++ b/bot/exts/filtering/_filter_lists/filter_list.py @@ -65,7 +65,7 @@ class AtomicList: """Provide a short description identifying the list with its name and type.""" return f"{past_tense(self.list_type.name.lower())} {self.name.lower()}" - def filter_list_result(self, ctx: FilterContext) -> list[Filter]: + async def filter_list_result(self, ctx: FilterContext) -> list[Filter]: """ Sift through the list of filters, and return only the ones which apply to the given context. @@ -79,10 +79,12 @@ class AtomicList: If the filter is relevant in context, see if it actually triggers. """ - return self._create_filter_list_result(ctx, self.defaults, self.filters.values()) + return await self._create_filter_list_result(ctx, self.defaults, self.filters.values()) @staticmethod - def _create_filter_list_result(ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter]) -> list[Filter]: + async def _create_filter_list_result( + ctx: FilterContext, defaults: Defaults, filters: Iterable[Filter] + ) -> list[Filter]: """A helper function to evaluate the result of `filter_list_result`.""" passed_by_default, failed_by_default = defaults.validations.evaluate(ctx) default_answer = not bool(failed_by_default) @@ -90,12 +92,12 @@ class AtomicList: relevant_filters = [] for filter_ in filters: if not filter_.validations: - if default_answer and filter_.triggered_on(ctx): + if default_answer and await filter_.triggered_on(ctx): relevant_filters.append(filter_) else: passed, failed = filter_.validations.evaluate(ctx) if not failed and failed_by_default < passed: - if filter_.triggered_on(ctx): + if await filter_.triggered_on(ctx): relevant_filters.append(filter_) return relevant_filters @@ -222,10 +224,10 @@ class SubscribingAtomicList(AtomicList): if filter_ not in self.subscriptions[event]: self.subscriptions[event].append(filter_.id) - def filter_list_result(self, ctx: FilterContext) -> list[Filter]: + async def filter_list_result(self, ctx: FilterContext) -> list[Filter]: """Sift through the list of filters, and return only the ones which apply to the given context.""" event_filters = [self.filters[id_] for id_ in self.subscriptions[ctx.event]] - return self._create_filter_list_result(ctx, self.defaults, event_filters) + return await self._create_filter_list_result(ctx, self.defaults, event_filters) class UniquesListBase(FilterList[UniqueFilter]): diff --git a/bot/exts/filtering/_filter_lists/invite.py b/bot/exts/filtering/_filter_lists/invite.py index 0b84aec0e..911b951dd 100644 --- a/bot/exts/filtering/_filter_lists/invite.py +++ b/bot/exts/filtering/_filter_lists/invite.py @@ -81,7 +81,7 @@ class InviteList(FilterList[InviteFilter]): # Find any blocked invites new_ctx = ctx.replace(content={invite.guild.id for invite in invites_for_inspection.values()}) - triggered = self[ListType.DENY].filter_list_result(new_ctx) + triggered = await self[ListType.DENY].filter_list_result(new_ctx) blocked_guilds = {filter_.content for filter_ in triggered} blocked_invites = { code: invite for code, invite in invites_for_inspection.items() if invite.guild.id in blocked_guilds @@ -100,7 +100,8 @@ class InviteList(FilterList[InviteFilter]): if check_if_allowed: # Whether unknown invites need to be checked. new_ctx = ctx.replace(content=guilds_for_inspection) allowed = { - filter_.content for filter_ in self[ListType.ALLOW].filters.values() if filter_.triggered_on(new_ctx) + filter_.content for filter_ in self[ListType.ALLOW].filters.values() + if await filter_.triggered_on(new_ctx) } unknown_invites.update({ code: invite for code, invite in invites_for_inspection.items() if invite.guild.id not in allowed diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py index c7d7cb444..274dc5ea7 100644 --- a/bot/exts/filtering/_filter_lists/token.py +++ b/bot/exts/filtering/_filter_lists/token.py @@ -53,7 +53,7 @@ class TokensList(FilterList[TokenFilter]): text = clean_input(text) ctx = ctx.replace(content=text) - triggers = self[ListType.DENY].filter_list_result(ctx) + triggers = await self[ListType.DENY].filter_list_result(ctx) actions = None messages = [] if triggers: diff --git a/bot/exts/filtering/_filter_lists/unique.py b/bot/exts/filtering/_filter_lists/unique.py index 5204065f9..ecc49af87 100644 --- a/bot/exts/filtering/_filter_lists/unique.py +++ b/bot/exts/filtering/_filter_lists/unique.py @@ -31,7 +31,7 @@ class UniquesList(UniquesListBase): async def actions_for(self, ctx: FilterContext) -> tuple[ActionSettings | None, list[str]]: """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" - triggers = self[ListType.DENY].filter_list_result(ctx) + triggers = await self[ListType.DENY].filter_list_result(ctx) actions = None messages = [] if triggers: diff --git a/bot/exts/filtering/_filters/antispam/attachments.py b/bot/exts/filtering/_filters/antispam/attachments.py new file mode 100644 index 000000000..216d9b886 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/attachments.py @@ -0,0 +1,43 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraAttachmentsSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of attachments before the filter is triggered." + + interval: int = 10 + threshold: int = 6 + + +class AttachmentsFilter(UniqueFilter): + """Detects too many attachments sent by a single user.""" + + name = "attachments" + events = (Event.MESSAGE,) + extra_fields_type = ExtraAttachmentsSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author and len(msg.attachments) > 0} + total_recent_attachments = sum(len(msg.attachments) for msg in detected_messages) + + if total_recent_attachments > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_recent_attachments} attachments" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/burst.py b/bot/exts/filtering/_filters/antispam/burst.py new file mode 100644 index 000000000..d78107d0a --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/burst.py @@ -0,0 +1,41 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraBurstSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of messages before the filter is triggered." + + interval: int = 10 + threshold: int = 7 + + +class BurstFilter(UniqueFilter): + """Detects too many messages sent by a single user.""" + + name = "burst" + events = (Event.MESSAGE,) + extra_fields_type = ExtraBurstSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + if len(detected_messages) > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {len(detected_messages)} messages" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/chars.py b/bot/exts/filtering/_filters/antispam/chars.py new file mode 100644 index 000000000..5c4fa201c --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/chars.py @@ -0,0 +1,43 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraCharsSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of characters before the filter is triggered." + + interval: int = 5 + threshold: int = 4_200 + + +class CharsFilter(UniqueFilter): + """Detects too many characters sent by a single user.""" + + name = "chars" + events = (Event.MESSAGE,) + extra_fields_type = ExtraCharsSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + total_recent_chars = sum(len(msg.content) for msg in relevant_messages) + + if total_recent_chars > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_recent_chars} characters" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/duplicates.py b/bot/exts/filtering/_filters/antispam/duplicates.py index 5df2bb5c0..60d5c322c 100644 --- a/bot/exts/filtering/_filters/antispam/duplicates.py +++ b/bot/exts/filtering/_filters/antispam/duplicates.py @@ -15,7 +15,7 @@ class ExtraDuplicatesSettings(BaseModel): interval_description: ClassVar[str] = ( "Look for rule violations in messages from the last `interval` number of seconds." ) - threshold_description: ClassVar[str] = "Number of duplicate messages required to trigger the filter." + threshold_description: ClassVar[str] = "Maximum number of duplicate messages before the filter is triggered." interval: int = 10 threshold: int = 3 @@ -28,7 +28,7 @@ class DuplicatesFilter(UniqueFilter): events = (Event.MESSAGE,) extra_fields_type = ExtraDuplicatesSettings - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Search for the filter's content within a given context.""" earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) diff --git a/bot/exts/filtering/_filters/antispam/emoji.py b/bot/exts/filtering/_filters/antispam/emoji.py new file mode 100644 index 000000000..0511e4a7b --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/emoji.py @@ -0,0 +1,53 @@ +import re +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from emoji import demojize +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +DISCORD_EMOJI_RE = re.compile(r"<:\w+:\d+>|:\w+:") +CODE_BLOCK_RE = re.compile(r"```.*?```", flags=re.DOTALL) + + +class ExtraEmojiSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of emojis before the filter is triggered." + + interval: int = 10 + threshold: int = 20 + + +class EmojiFilter(UniqueFilter): + """Detects too many emojis sent by a single user.""" + + name = "emoji" + events = (Event.MESSAGE,) + extra_fields_type = ExtraEmojiSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + + # Get rid of code blocks in the message before searching for emojis. + # Convert Unicode emojis to :emoji: format to get their count. + total_emojis = sum( + len(DISCORD_EMOJI_RE.findall(demojize(CODE_BLOCK_RE.sub("", msg.content)))) + for msg in relevant_messages + ) + + if total_emojis > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_emojis} emojis" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/links.py b/bot/exts/filtering/_filters/antispam/links.py new file mode 100644 index 000000000..76fe53e70 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/links.py @@ -0,0 +1,52 @@ +import re +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +LINK_RE = re.compile(r"(https?://\S+)") + + +class ExtraLinksSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of links before the filter is triggered." + + interval: int = 10 + threshold: int = 10 + + +class DuplicatesFilter(UniqueFilter): + """Detects too many links sent by a single user.""" + + name = "links" + events = (Event.MESSAGE,) + extra_fields_type = ExtraLinksSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + + total_links = 0 + messages_with_links = 0 + for msg in relevant_messages: + total_matches = len(LINK_RE.findall(msg.content)) + if total_matches: + messages_with_links += 1 + total_links += total_matches + + if total_links > self.extra_fields.threshold and messages_with_links > 1: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_links} links" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/mentions.py b/bot/exts/filtering/_filters/antispam/mentions.py new file mode 100644 index 000000000..29a2d5606 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/mentions.py @@ -0,0 +1,90 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from botcore.utils.logging import get_logger +from discord import DeletedReferencedMessage, MessageType, NotFound +from pydantic import BaseModel + +import bot +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +log = get_logger(__name__) + + +class ExtraMentionsSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of distinct mentions before the filter is triggered." + + interval: int = 10 + threshold: int = 5 + + +class DuplicatesFilter(UniqueFilter): + """ + Detects total mentions exceeding the limit sent by a single user. + + Excludes mentions that are bots, themselves, or replied users. + + In very rare cases, may not be able to determine a + mention was to a reply, in which case it is not ignored. + """ + + name = "mentions" + events = (Event.MESSAGE,) + extra_fields_type = ExtraMentionsSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + + # We use `msg.mentions` here as that is supplied by the api itself, to determine who was mentioned. + # Additionally, `msg.mentions` includes the user replied to, even if the mention doesn't occur in the body. + # In order to exclude users who are mentioned as a reply, we check if the msg has a reference + # + # While we could use regex to parse the message content, and get a list of + # the mentions, that solution is very prone to breaking. + # We would need to deal with codeblocks, escaping markdown, and any discrepancies between + # our implementation and discord's Markdown parser which would cause false positives or false negatives. + total_recent_mentions = 0 + for msg in relevant_messages: + # We check if the message is a reply, and if it is try to get the author + # since we ignore mentions of a user that we're replying to + reply_author = None + + if msg.type == MessageType.reply: + ref = msg.reference + + if not (resolved := ref.resolved): + # It is possible, in a very unusual situation, for a message to have a reference + # that is both not in the cache, and deleted while running this function. + # In such a situation, this will throw an error which we catch. + try: + resolved = await bot.instance.get_partial_messageable(resolved.channel_id).fetch_message( + resolved.message_id + ) + except NotFound: + log.info('Could not fetch the reference message as it has been deleted.') + + if resolved and not isinstance(resolved, DeletedReferencedMessage): + reply_author = resolved.author + + for user in msg.mentions: + # Don't count bot or self mentions, or the user being replied to (if applicable) + if user.bot or user in {msg.author, reply_author}: + continue + total_recent_mentions += 1 + + if total_recent_mentions > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_recent_mentions} mentions" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/newlines.py b/bot/exts/filtering/_filters/antispam/newlines.py new file mode 100644 index 000000000..b15a35219 --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/newlines.py @@ -0,0 +1,61 @@ +import re +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + +NEWLINES = re.compile(r"(\n+)") + + +class ExtraNewlinesSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of newlines before the filter is triggered." + consecutive_threshold_description: ClassVar[str] = ( + "Maximum number of consecutive newlines before the filter is triggered." + ) + + interval: int = 10 + threshold: int = 100 + consecutive_threshold: int = 10 + + +class NewlinesFilter(UniqueFilter): + """Detects too many newlines sent by a single user.""" + + name = "newlines" + events = (Event.MESSAGE,) + extra_fields_type = ExtraNewlinesSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + + # Identify groups of newline characters and get group & total counts + newline_counts = [] + for msg in relevant_messages: + newline_counts += [len(group) for group in NEWLINES.findall(msg.content)] + total_recent_newlines = sum(newline_counts) + # Get maximum newline group size + max_newline_group = max(newline_counts, default=0) + + # Check first for total newlines, if this passes then check for large groupings + if total_recent_newlines > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_recent_newlines} newlines" + return True + if max_newline_group > self.extra_fields.consecutive_threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {max_newline_group} consecutive newlines" + return True + return False diff --git a/bot/exts/filtering/_filters/antispam/role_mentions.py b/bot/exts/filtering/_filters/antispam/role_mentions.py new file mode 100644 index 000000000..49de642fa --- /dev/null +++ b/bot/exts/filtering/_filters/antispam/role_mentions.py @@ -0,0 +1,42 @@ +from datetime import timedelta +from itertools import takewhile +from typing import ClassVar + +import arrow +from pydantic import BaseModel + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.filter import UniqueFilter + + +class ExtraRoleMentionsSettings(BaseModel): + """Extra settings for when to trigger the antispam rule.""" + + interval_description: ClassVar[str] = ( + "Look for rule violations in messages from the last `interval` number of seconds." + ) + threshold_description: ClassVar[str] = "Maximum number of role mentions before the filter is triggered." + + interval: int = 10 + threshold: int = 3 + + +class DuplicatesFilter(UniqueFilter): + """Detects too many role mentions sent by a single user.""" + + name = "role_mentions" + events = (Event.MESSAGE,) + extra_fields_type = ExtraRoleMentionsSettings + + async def triggered_on(self, ctx: FilterContext) -> bool: + """Search for the filter's content within a given context.""" + earliest_relevant_at = arrow.utcnow() - timedelta(seconds=self.extra_fields.interval) + relevant_messages = list(takewhile(lambda msg: msg.created_at > earliest_relevant_at, ctx.content)) + detected_messages = {msg for msg in relevant_messages if msg.author == ctx.author} + total_recent_mentions = sum(len(msg.role_mentions) for msg in relevant_messages) + + if total_recent_mentions > self.extra_fields.threshold: + ctx.related_messages |= detected_messages + ctx.filter_info[self] = f"sent {total_recent_mentions} role mentions" + return True + return False diff --git a/bot/exts/filtering/_filters/domain.py b/bot/exts/filtering/_filters/domain.py index e22cafbb7..4cc3a6f5a 100644 --- a/bot/exts/filtering/_filters/domain.py +++ b/bot/exts/filtering/_filters/domain.py @@ -34,7 +34,7 @@ class DomainFilter(Filter): name = "domain" extra_fields_type = ExtraDomainSettings - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Searches for a domain within a given context.""" domain = tldextract.extract(self.content).registered_domain diff --git a/bot/exts/filtering/_filters/extension.py b/bot/exts/filtering/_filters/extension.py index 926a6a2fb..f3f64532f 100644 --- a/bot/exts/filtering/_filters/extension.py +++ b/bot/exts/filtering/_filters/extension.py @@ -11,7 +11,7 @@ class ExtensionFilter(Filter): name = "extension" - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Searches for an attachment extension in the context content, given as a set of extensions.""" return self.content in ctx.content diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py index b0d19d3a8..4ae7ec45f 100644 --- a/bot/exts/filtering/_filters/filter.py +++ b/bot/exts/filtering/_filters/filter.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any from pydantic import ValidationError @@ -48,7 +48,7 @@ class Filter(FieldRequiring): return settings, filter_settings @abstractmethod - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Search for the filter's content within a given context.""" @classmethod @@ -81,7 +81,7 @@ class Filter(FieldRequiring): return string -class UniqueFilter(Filter, ABC): +class UniqueFilter(Filter): """ Unique filters are ones that should only be run once in a given context. diff --git a/bot/exts/filtering/_filters/invite.py b/bot/exts/filtering/_filters/invite.py index ac4f62cb6..e8f3e9851 100644 --- a/bot/exts/filtering/_filters/invite.py +++ b/bot/exts/filtering/_filters/invite.py @@ -20,7 +20,7 @@ class InviteFilter(Filter): super().__init__(filter_data, defaults_data) self.content = int(self.content) - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Searches for a guild ID in the context content, given as a set of IDs.""" return self.content in ctx.content diff --git a/bot/exts/filtering/_filters/token.py b/bot/exts/filtering/_filters/token.py index 04e30cb03..f61d38846 100644 --- a/bot/exts/filtering/_filters/token.py +++ b/bot/exts/filtering/_filters/token.py @@ -11,7 +11,7 @@ class TokenFilter(Filter): name = "token" - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Searches for a regex pattern within a given context.""" pattern = self.content diff --git a/bot/exts/filtering/_filters/unique/discord_token.py b/bot/exts/filtering/_filters/unique/discord_token.py index 7fdb800df..731df198c 100644 --- a/bot/exts/filtering/_filters/unique/discord_token.py +++ b/bot/exts/filtering/_filters/unique/discord_token.py @@ -69,7 +69,7 @@ class DiscordTokenFilter(UniqueFilter): """Get currently loaded ModLog cog instance.""" return bot.instance.get_cog("ModLog") - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Return whether the message contains Discord client tokens.""" found_token = self.find_token_in_message(ctx.content) if not found_token: diff --git a/bot/exts/filtering/_filters/unique/everyone.py b/bot/exts/filtering/_filters/unique/everyone.py index 06d3a19bb..a32e67cc5 100644 --- a/bot/exts/filtering/_filters/unique/everyone.py +++ b/bot/exts/filtering/_filters/unique/everyone.py @@ -18,7 +18,7 @@ class EveryoneFilter(UniqueFilter): name = "everyone" events = (Event.MESSAGE, Event.MESSAGE_EDIT) - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Search for the filter's content within a given context.""" # First pass to avoid running re.sub on every message if not EVERYONE_PING_RE.search(ctx.content): diff --git a/bot/exts/filtering/_filters/unique/rich_embed.py b/bot/exts/filtering/_filters/unique/rich_embed.py index a0d9e263f..09d513373 100644 --- a/bot/exts/filtering/_filters/unique/rich_embed.py +++ b/bot/exts/filtering/_filters/unique/rich_embed.py @@ -17,7 +17,7 @@ class RichEmbedFilter(UniqueFilter): name = "rich_embed" events = (Event.MESSAGE, Event.MESSAGE_EDIT) - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Determine if `msg` contains any rich embeds not auto-generated from a URL.""" if ctx.embeds: for embed in ctx.embeds: diff --git a/bot/exts/filtering/_filters/unique/webhook.py b/bot/exts/filtering/_filters/unique/webhook.py index b9d98db35..16ff1b213 100644 --- a/bot/exts/filtering/_filters/unique/webhook.py +++ b/bot/exts/filtering/_filters/unique/webhook.py @@ -29,7 +29,7 @@ class WebhookFilter(UniqueFilter): """Get current instance of `ModLog`.""" return bot.instance.get_cog("ModLog") - def triggered_on(self, ctx: FilterContext) -> bool: + async def triggered_on(self, ctx: FilterContext) -> bool: """Search for a webhook in the given content. If found, attempt to delete it.""" matches = set(WEBHOOK_URL_RE.finditer(ctx.content)) if not matches: diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 514ef39e1..aad36af14 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -175,7 +175,6 @@ class Filtering(Cog): self.message_cache.append(msg) ctx = FilterContext(Event.MESSAGE, msg.author, msg.channel, msg.content, msg, msg.embeds) - result_actions, list_messages = await self._resolve_action(ctx) if result_actions: await result_actions.action(ctx) @@ -194,7 +193,7 @@ class Filtering(Cog): @blocklist.command(name="list", aliases=("get",)) async def bl_list(self, ctx: Context, list_name: Optional[str] = None) -> None: """List the contents of a specified blacklist.""" - result = self._resolve_list_type_and_name(ctx, ListType.DENY, list_name) + result = await self._resolve_list_type_and_name(ctx, ListType.DENY, list_name) if not result: return list_type, filter_list = result @@ -237,7 +236,7 @@ class Filtering(Cog): @allowlist.command(name="list", aliases=("get",)) async def al_list(self, ctx: Context, list_name: Optional[str] = None) -> None: """List the contents of a specified whitelist.""" - result = self._resolve_list_type_and_name(ctx, ListType.ALLOW, list_name) + result = await self._resolve_list_type_and_name(ctx, ListType.ALLOW, list_name) if not result: return list_type, filter_list = result diff --git a/tests/bot/exts/filtering/test_filters.py b/tests/bot/exts/filtering/test_filters.py index 214637b52..29b50188a 100644 --- a/tests/bot/exts/filtering/test_filters.py +++ b/tests/bot/exts/filtering/test_filters.py @@ -5,7 +5,7 @@ from bot.exts.filtering._filters.token import TokenFilter from tests.helpers import MockMember, MockMessage, MockTextChannel -class FilterTests(unittest.TestCase): +class FilterTests(unittest.IsolatedAsyncioTestCase): """Test functionality of the token filter.""" def setUp(self) -> None: @@ -14,7 +14,7 @@ class FilterTests(unittest.TestCase): message = MockMessage(author=member, channel=channel) self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) - def test_token_filter_triggers(self): + async def test_token_filter_triggers(self): """The filter should evaluate to True only if its token is found in the context content.""" test_cases = ( (r"hi", "oh hi there", True), @@ -37,5 +37,5 @@ class FilterTests(unittest.TestCase): "additional_field": "{}" # noqa: P103 }) self.ctx.content = content - result = filter_.triggered_on(self.ctx) + result = await filter_.triggered_on(self.ctx) self.assertEqual(result, expected) -- cgit v1.2.3 From 44bf2477d675035bb4024d7cd4fa400d7f2b8942 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Sun, 22 Jan 2023 21:21:57 +0200 Subject: Bring back old system tests --- .../exts/filtering/test_discord_token_filter.py | 276 +++++++++++++++++++++ tests/bot/exts/filtering/test_extension_filter.py | 139 +++++++++++ tests/bot/exts/filtering/test_filters.py | 41 --- tests/bot/exts/filtering/test_settings_entries.py | 66 ++++- tests/bot/exts/filtering/test_token_filter.py | 49 ++++ 5 files changed, 522 insertions(+), 49 deletions(-) create mode 100644 tests/bot/exts/filtering/test_discord_token_filter.py create mode 100644 tests/bot/exts/filtering/test_extension_filter.py delete mode 100644 tests/bot/exts/filtering/test_filters.py create mode 100644 tests/bot/exts/filtering/test_token_filter.py (limited to 'tests') diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py new file mode 100644 index 000000000..ef124e6ff --- /dev/null +++ b/tests/bot/exts/filtering/test_discord_token_filter.py @@ -0,0 +1,276 @@ +import unittest +from re import Match +from unittest import mock +from unittest.mock import MagicMock, patch + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.unique import discord_token +from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter, Token +from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel, autospec + + +class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase): + """Tests the DiscordTokenFilter class.""" + + def setUp(self): + """Adds the filter, a bot, and a message to the instance for usage in tests.""" + now = arrow.utcnow().timestamp() + self.filter = DiscordTokenFilter({ + "id": 1, + "content": "discord_token", + "description": None, + "settings": {}, + "additional_field": "{}", # noqa: P103 + "created_at": now, + "updated_at": now + }) + + self.msg = MockMessage(id=555, content="hello world") + self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name) + + member = MockMember(id=123) + channel = MockTextChannel(id=345) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.msg) + + def test_extract_user_id_valid(self): + """Should consider user IDs valid if they decode into an integer ID.""" + id_pairs = ( + ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332), + ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904), + ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641), + ) + + for token_id, user_id in id_pairs: + with self.subTest(token_id=token_id): + result = DiscordTokenFilter.extract_user_id(token_id) + self.assertEqual(result, user_id) + + def test_extract_user_id_invalid(self): + """Should consider non-digit and non-ASCII IDs invalid.""" + ids = ( + ("SGVsbG8gd29ybGQ", "non-digit ASCII"), + ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"), + ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"), + ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"), + ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for user_id, msg in ids: + with self.subTest(msg=msg): + result = DiscordTokenFilter.extract_user_id(user_id) + self.assertIsNone(result) + + def test_is_valid_timestamp_valid(self): + """Should consider timestamps valid if they're greater than the Discord epoch.""" + timestamps = ( + "XsyRkw", + "Xrim9Q", + "XsyR-w", + "XsySD_", + "Dn9r_A", + ) + + for timestamp in timestamps: + with self.subTest(timestamp=timestamp): + result = DiscordTokenFilter.is_valid_timestamp(timestamp) + self.assertTrue(result) + + def test_is_valid_timestamp_invalid(self): + """Should consider timestamps invalid if they're before Discord epoch or can't be parsed.""" + timestamps = ( + ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"), + ("ew", "123"), + ("AoIKgA", "42076800"), + ("{hello}[world]&(bye!)", "ASCII invalid Base64"), + ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"), + ) + + for timestamp, msg in timestamps: + with self.subTest(msg=msg): + result = DiscordTokenFilter.is_valid_timestamp(timestamp) + self.assertFalse(result) + + def test_is_valid_hmac_valid(self): + """Should consider an HMAC valid if it has at least 3 unique characters.""" + valid_hmacs = ( + "VXmErH7j511turNpfURmb0rVNm8", + "Ysnu2wacjaKs7qnoo46S8Dm2us8", + "sJf6omBPORBPju3WJEIAcwW9Zds", + "s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for hmac in valid_hmacs: + with self.subTest(msg=hmac): + result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) + self.assertTrue(result) + + def test_is_invalid_hmac_invalid(self): + """Should consider an HMAC invalid if has fewer than 3 unique characters.""" + invalid_hmacs = ( + ("xxxxxxxxxxxxxxxxxx", "Single character"), + ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"), + ("ASFasfASFasfASFASsf", "Three characters alternating-case"), + ("asdasdasdasdasdasdasd", "Three characters one case"), + ) + + for hmac, msg in invalid_hmacs: + with self.subTest(msg=msg): + result = DiscordTokenFilter.is_maybe_valid_hmac(hmac) + self.assertFalse(result) + + async def test_no_trigger_when_no_token(self): + """False should be returned if the message doesn't contain a Discord token.""" + return_value = await self.filter.triggered_on(self.ctx) + + self.assertFalse(return_value) + + @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") + @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") + @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") + def test_find_token_valid_match( + self, + token_re, + token_cls, + extract_user_id, + is_valid_timestamp, + is_maybe_valid_hmac, + ): + """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`.""" + matches = [ + mock.create_autospec(Match, spec_set=True, instance=True), + mock.create_autospec(Match, spec_set=True, instance=True), + ] + tokens = [ + mock.create_autospec(Token, spec_set=True, instance=True), + mock.create_autospec(Token, spec_set=True, instance=True), + ] + + token_re.finditer.return_value = matches + token_cls.side_effect = tokens + extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid. + is_valid_timestamp.return_value = True + is_maybe_valid_hmac.return_value = True + + return_value = DiscordTokenFilter.find_token_in_message(self.msg) + + self.assertEqual(tokens[1], return_value) + + @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac") + @autospec("bot.exts.filtering._filters.unique.discord_token", "Token") + @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE") + def test_find_token_invalid_matches( + self, + token_re, + token_cls, + extract_user_id, + is_valid_timestamp, + is_maybe_valid_hmac, + ): + """None should be returned if no matches have valid user IDs, HMACs, and timestamps.""" + token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)] + token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True) + extract_user_id.return_value = None + is_valid_timestamp.return_value = False + is_maybe_valid_hmac.return_value = False + + return_value = DiscordTokenFilter.find_token_in_message(self.msg) + + self.assertIsNone(return_value) + + def test_regex_invalid_tokens(self): + """Messages without anything looking like a token are not matched.""" + tokens = ( + "", + "lemon wins", + "..", + "x.y", + "x.y.", + ".y.z", + ".y.", + "..z", + "x..z", + " . . ", + "\n.\n.\n", + "hellö.world.bye", + "base64.nötbåse64.morebase64", + "19jd3J.dfkm3d.€víł§tüff", + ) + + for token in tokens: + with self.subTest(token=token): + results = discord_token.TOKEN_RE.findall(token) + self.assertEqual(len(results), 0) + + def test_regex_valid_tokens(self): + """Messages that look like tokens should be matched.""" + # Don't worry, these tokens have been invalidated. + tokens = ( + "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8", + "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8", + "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds", + "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4", + ) + + for token in tokens: + with self.subTest(token=token): + results = discord_token.TOKEN_RE.fullmatch(token) + self.assertIsNotNone(results, f"{token} was not matched by the regex") + + def test_regex_matches_multiple_valid(self): + """Should support multiple matches in the middle of a string.""" + token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8" + token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc" + message = f"garbage {token_1} hello {token_2} world" + + results = discord_token.TOKEN_RE.finditer(message) + results = [match[0] for match in results] + self.assertCountEqual((token_1, token_2), results) + + @autospec("bot.exts.filtering._filters.unique.discord_token", "LOG_MESSAGE") + def test_format_log_message(self, log_message): + """Should correctly format the log message with info from the message and token.""" + token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + log_message.format.return_value = "Howdy" + + return_value = DiscordTokenFilter.format_log_message(self.msg.author, self.msg.channel, token) + + self.assertEqual(return_value, log_message.format.return_value) + + @patch("bot.instance", MockBot()) + @autospec("bot.exts.filtering._filters.unique.discord_token", "UNKNOWN_USER_LOG_MESSAGE") + @autospec("bot.exts.filtering._filters.unique.discord_token", "get_or_fetch_member") + async def test_format_userid_log_message_unknown(self, get_or_fetch_member, unknown_user_log_message): + """Should correctly format the user ID portion when the actual user it belongs to is unknown.""" + token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + unknown_user_log_message.format.return_value = " Partner" + get_or_fetch_member.return_value = None + + return_value = await DiscordTokenFilter.format_userid_log_message(token) + + self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False)) + + @patch("bot.instance", MockBot()) + @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") + async def test_format_userid_log_message_bot(self, known_user_log_message): + """Should correctly format the user ID portion when the ID belongs to a known bot.""" + token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + known_user_log_message.format.return_value = " Partner" + + return_value = await DiscordTokenFilter.format_userid_log_message(token) + + self.assertEqual(return_value, (known_user_log_message.format.return_value, True)) + + @patch("bot.instance", MockBot()) + @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE") + async def test_format_log_message_user_token_user(self, user_token_message): + """Should correctly format the user ID portion when the ID belongs to a known user.""" + token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4") + user_token_message.format.return_value = "Partner" + + return_value = await DiscordTokenFilter.format_userid_log_message(token) + + self.assertEqual(return_value, (user_token_message.format.return_value, True)) diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py new file mode 100644 index 000000000..0ad41116d --- /dev/null +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -0,0 +1,139 @@ +import unittest +from unittest.mock import MagicMock, patch + +import arrow + +from bot.constants import Channels +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filter_lists import extension +from bot.exts.filtering._filter_lists.extension import ExtensionsList +from bot.exts.filtering._filter_lists.filter_list import ListType +from tests.helpers import MockAttachment, MockBot, MockMember, MockMessage, MockTextChannel + +BOT = MockBot() + + +class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): + """Test the ExtensionsList class.""" + + def setUp(self): + """Sets up fresh objects for each test.""" + self.filter_list = ExtensionsList(MagicMock()) + now = arrow.utcnow().timestamp() + filters = [] + self.whitelist = [".first", ".second", ".third"] + for i, filter_content in enumerate(self.whitelist, start=1): + filters.append({ + "id": i, "content": filter_content, "description": None, "settings": {}, + "additional_field": "{}", "created_at": now, "updated_at": now # noqa: P103 + }) + self.filter_list.add_list({ + "id": 1, + "list_type": 1, + "created_at": now, + "updated_at": now, + "settings": {}, + "filters": filters + }) + + self.message = MockMessage() + member = MockMember(id=123) + channel = MockTextChannel(id=345) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.message) + + @patch("bot.instance", BOT) + async def test_message_with_allowed_attachment(self): + """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" + attachment = MockAttachment(filename="python.first") + self.message.attachments = [attachment] + + result = await self.filter_list.actions_for(self.ctx) + + self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) + + @patch("bot.instance", BOT) + async def test_message_without_attachment(self): + """Messages without attachments should return no triggers, messages, or actions.""" + result = await self.filter_list.actions_for(self.ctx) + + self.assertEqual(result, (None, [], {})) + + @patch("bot.instance", BOT) + async def test_message_with_illegal_extension(self): + """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + result = await self.filter_list.actions_for(self.ctx) + + self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) + + @patch("bot.instance", BOT) + async def test_python_file_redirect_embed_description(self): + """A message containing a .py file should result in an embed redirecting the user to our paste site.""" + attachment = MockAttachment(filename="python.py") + self.message.attachments = [attachment] + + await self.filter_list.actions_for(self.ctx) + + self.assertEqual(self.ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + + @patch("bot.instance", BOT) + async def test_txt_file_redirect_embed_description(self): + """A message containing a .txt/.json/.csv file should result in the correct embed.""" + test_values = ( + ("text", ".txt"), + ("json", ".json"), + ("csv", ".csv"), + ) + + for file_name, disallowed_extension in test_values: + with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): + + attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") + self.message.attachments = [attachment] + + await self.filter_list.actions_for(self.ctx) + + self.assertEqual( + self.ctx.dm_embed, + extension.TXT_EMBED_DESCRIPTION.format( + blocked_extension=disallowed_extension, + ) + ) + + @patch("bot.instance", BOT) + async def test_other_disallowed_extension_embed_description(self): + """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.attachments = [attachment] + + await self.filter_list.actions_for(self.ctx) + meta_channel = BOT.get_channel(Channels.meta) + + self.assertEqual( + self.ctx.dm_embed, + extension.DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=", ".join(self.whitelist), + blocked_extensions_str=".disallowed", + meta_channel_mention=meta_channel.mention + ) + ) + + @patch("bot.instance", BOT) + async def test_get_disallowed_extensions(self): + """The return value should include all non-whitelisted extensions.""" + test_values = ( + ([], []), + (self.whitelist, []), + ([".first"], []), + ([".first", ".disallowed"], ["`.disallowed`"]), + ([".disallowed"], ["`.disallowed`"]), + ([".disallowed", ".illegal"], ["`.disallowed`", "`.illegal`"]), + ) + + for extensions, expected_disallowed_extensions in test_values: + with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): + self.message.attachments = [MockAttachment(filename=f"filename{ext}") for ext in extensions] + result = await self.filter_list.actions_for(self.ctx) + self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/filtering/test_filters.py b/tests/bot/exts/filtering/test_filters.py deleted file mode 100644 index 29b50188a..000000000 --- a/tests/bot/exts/filtering/test_filters.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest - -from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._filters.token import TokenFilter -from tests.helpers import MockMember, MockMessage, MockTextChannel - - -class FilterTests(unittest.IsolatedAsyncioTestCase): - """Test functionality of the token filter.""" - - def setUp(self) -> None: - member = MockMember(id=123) - channel = MockTextChannel(id=345) - message = MockMessage(author=member, channel=channel) - self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) - - async def test_token_filter_triggers(self): - """The filter should evaluate to True only if its token is found in the context content.""" - test_cases = ( - (r"hi", "oh hi there", True), - (r"hi", "goodbye", False), - (r"bla\d{2,4}", "bla18", True), - (r"bla\d{2,4}", "bla1", False) - ) - - for pattern, content, expected in test_cases: - with self.subTest( - pattern=pattern, - content=content, - expected=expected, - ): - filter_ = TokenFilter({ - "id": 1, - "content": pattern, - "description": None, - "settings": {}, - "additional_field": "{}" # noqa: P103 - }) - self.ctx.content = content - result = await filter_.triggered_on(self.ctx) - self.assertEqual(result, expected) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index 34b155d6b..5a1eb6fe6 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -50,7 +50,9 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_for_empty_channel_scope(self): """A filter is enabled for all channels by default.""" channel = MockTextChannel() - scope = ChannelScope(disabled_channels=None, disabled_categories=None, enabled_channels=None) + scope = ChannelScope( + disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=None + ) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -60,7 +62,9 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_for_disabled_channel(self): """A filter shouldn't trigger if it's been disabled in the channel.""" channel = MockTextChannel(id=123) - scope = ChannelScope(disabled_channels=["123"], disabled_categories=None, enabled_channels=None) + scope = ChannelScope( + disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=None + ) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -70,7 +74,9 @@ class FilterTests(unittest.TestCase): def test_context_doesnt_trigger_in_disabled_category(self): """A filter shouldn't trigger if it's been disabled in the category.""" channel = MockTextChannel(category=MockCategoryChannel(id=456)) - scope = ChannelScope(disabled_channels=None, disabled_categories=["456"], enabled_channels=None) + scope = ChannelScope( + disabled_channels=None, disabled_categories=["456"], enabled_channels=None, enabled_categories=None + ) self.ctx.channel = channel result = scope.triggers_on(self.ctx) @@ -80,13 +86,51 @@ class FilterTests(unittest.TestCase): def test_context_triggers_in_enabled_channel_in_disabled_category(self): """A filter should trigger in an enabled channel even if it's been disabled in the category.""" channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) - scope = ChannelScope(disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"]) + scope = ChannelScope( + disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"], enabled_categories=None + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertTrue(result) + + def test_context_triggers_inside_enabled_category(self): + """A filter shouldn't trigger outside enabled categories, if there are any.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["234"] + ) self.ctx.channel = channel result = scope.triggers_on(self.ctx) self.assertTrue(result) + def test_context_doesnt_trigger_outside_enabled_category(self): + """A filter shouldn't trigger outside enabled categories, if there are any.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["789"] + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + + def test_context_doesnt_trigger_inside_disabled_channel_in_enabled_category(self): + """A filter shouldn't trigger outside enabled categories, if there are any.""" + channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234)) + scope = ChannelScope( + disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=["234"] + ) + self.ctx.channel = channel + + result = scope.triggers_on(self.ctx) + + self.assertFalse(result) + def test_filtering_dms_when_necessary(self): """A filter correctly ignores or triggers in a channel depending on the value of FilterDM.""" cases = ( @@ -112,14 +156,16 @@ class FilterTests(unittest.TestCase): infraction_reason="hi", infraction_duration=10, dm_content="how", - dm_embed="what is" + dm_embed="what is", + infraction_channel=0 ) infraction2 = InfractionAndNotification( infraction_type="MUTE", infraction_reason="there", infraction_duration=20, dm_content="are you", - dm_embed="your name" + dm_embed="your name", + infraction_channel=0 ) result = infraction1 | infraction2 @@ -132,6 +178,7 @@ class FilterTests(unittest.TestCase): "infraction_duration": 20.0, "dm_content": "are you", "dm_embed": "your name", + "infraction_channel": 0 } ) @@ -142,14 +189,16 @@ class FilterTests(unittest.TestCase): infraction_reason="hi", infraction_duration=20, dm_content="", - dm_embed="" + dm_embed="", + infraction_channel=0 ) infraction2 = InfractionAndNotification( infraction_type="BAN", infraction_reason="", infraction_duration=10, dm_content="there", - dm_embed="" + dm_embed="", + infraction_channel=0 ) result = infraction1 | infraction2 @@ -162,5 +211,6 @@ class FilterTests(unittest.TestCase): "infraction_duration": 10.0, "dm_content": "there", "dm_embed": "", + "infraction_channel": 0 } ) diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py new file mode 100644 index 000000000..0dfc8ae9f --- /dev/null +++ b/tests/bot/exts/filtering/test_token_filter.py @@ -0,0 +1,49 @@ +import unittest + +import arrow + +from bot.exts.filtering._filter_context import Event, FilterContext +from bot.exts.filtering._filters.token import TokenFilter +from tests.helpers import MockMember, MockMessage, MockTextChannel + + +class TokenFilterTests(unittest.IsolatedAsyncioTestCase): + """Test functionality of the token filter.""" + + def setUp(self) -> None: + member = MockMember(id=123) + channel = MockTextChannel(id=345) + message = MockMessage(author=member, channel=channel) + self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message) + + async def test_token_filter_triggers(self): + """The filter should evaluate to True only if its token is found in the context content.""" + test_cases = ( + (r"hi", "oh hi there", True), + (r"hi", "goodbye", False), + (r"bla\d{2,4}", "bla18", True), + (r"bla\d{2,4}", "bla1", False), + # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6 + (r"TOKEN", "https://google.com TOKEN", True), + (r"TOKEN", "https://google.com something else", False) + ) + now = arrow.utcnow().timestamp() + + for pattern, content, expected in test_cases: + with self.subTest( + pattern=pattern, + content=content, + expected=expected, + ): + filter_ = TokenFilter({ + "id": 1, + "content": pattern, + "description": None, + "settings": {}, + "additional_field": "{}", # noqa: P103 + "created_at": now, + "updated_at": now + }) + self.ctx.content = content + result = await filter_.triggered_on(self.ctx) + self.assertEqual(result, expected) -- cgit v1.2.3 From b64eee91c0e212e53307fbcdd72fe8794dd48d9d Mon Sep 17 00:00:00 2001 From: mbaruh Date: Thu, 23 Mar 2023 19:51:06 +0200 Subject: Fix filtering tests --- tests/bot/exts/filtering/test_settings_entries.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index 5a1eb6fe6..c5f0152b0 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -152,7 +152,7 @@ class FilterTests(unittest.TestCase): def test_infraction_merge_of_same_infraction_type(self): """When both infractions are of the same type, the one with the longer duration wins.""" infraction1 = InfractionAndNotification( - infraction_type="MUTE", + infraction_type="TIMEOUT", infraction_reason="hi", infraction_duration=10, dm_content="how", @@ -160,7 +160,7 @@ class FilterTests(unittest.TestCase): infraction_channel=0 ) infraction2 = InfractionAndNotification( - infraction_type="MUTE", + infraction_type="TIMEOUT", infraction_reason="there", infraction_duration=20, dm_content="are you", @@ -168,12 +168,12 @@ class FilterTests(unittest.TestCase): infraction_channel=0 ) - result = infraction1 | infraction2 + result = infraction1.union(infraction2) self.assertDictEqual( result.dict(), { - "infraction_type": Infraction.MUTE, + "infraction_type": Infraction.TIMEOUT, "infraction_reason": "there", "infraction_duration": 20.0, "dm_content": "are you", @@ -185,7 +185,7 @@ class FilterTests(unittest.TestCase): def test_infraction_merge_of_different_infraction_types(self): """If there are two different infraction types, the one higher up the hierarchy should be picked.""" infraction1 = InfractionAndNotification( - infraction_type="MUTE", + infraction_type="TIMEOUT", infraction_reason="hi", infraction_duration=20, dm_content="", @@ -201,7 +201,7 @@ class FilterTests(unittest.TestCase): infraction_channel=0 ) - result = infraction1 | infraction2 + result = infraction1.union(infraction2) self.assertDictEqual( result.dict(), -- cgit v1.2.3 From 509c7968dab875f8e3e7934647c757a2a73f724b Mon Sep 17 00:00:00 2001 From: mbaruh Date: Thu, 23 Mar 2023 19:59:05 +0200 Subject: Add support for snekbox IO in the new filtering system --- bot/exts/filtering/_filter_context.py | 19 ++++- bot/exts/filtering/_filter_lists/antispam.py | 4 +- bot/exts/filtering/_filter_lists/domain.py | 2 +- bot/exts/filtering/_filter_lists/extension.py | 56 ++++++++------- bot/exts/filtering/_filter_lists/invite.py | 2 +- bot/exts/filtering/_filter_lists/token.py | 2 +- .../filtering/_filters/unique/discord_token.py | 2 +- bot/exts/filtering/_filters/unique/everyone.py | 2 +- bot/exts/filtering/_filters/unique/webhook.py | 2 +- .../_settings_types/actions/remove_context.py | 4 +- bot/exts/filtering/_ui/ui.py | 6 +- bot/exts/filtering/filtering.py | 17 +++-- bot/exts/utils/snekbox/_cog.py | 84 ++++++++++------------ bot/exts/utils/snekbox/_io.py | 10 +-- tests/bot/exts/filtering/test_extension_filter.py | 30 ++++---- tests/bot/exts/utils/snekbox/test_snekbox.py | 8 +-- 16 files changed, 134 insertions(+), 116 deletions(-) (limited to 'tests') diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py index 8e1ed5788..483706e2a 100644 --- a/bot/exts/filtering/_filter_context.py +++ b/bot/exts/filtering/_filter_context.py @@ -5,12 +5,14 @@ from collections.abc import Callable, Coroutine, Iterable from dataclasses import dataclass, field, replace from enum import Enum, auto +import discord from discord import DMChannel, Embed, Member, Message, TextChannel, Thread, User from bot.utils.message_cache import MessageCache if typing.TYPE_CHECKING: from bot.exts.filtering._filters.filter import Filter + from bot.exts.utils.snekbox._io import FileAttachment class Event(Enum): @@ -19,6 +21,7 @@ class Event(Enum): MESSAGE = auto() MESSAGE_EDIT = auto() NICKNAME = auto() + SNEKBOX = auto() @dataclass @@ -32,6 +35,7 @@ class FilterContext: content: str | Iterable # What actually needs filtering. The Iterable type depends on the filter list. message: Message | None # The message involved embeds: list[Embed] = field(default_factory=list) # Any embeds involved + attachments: list[discord.Attachment | FileAttachment] = field(default_factory=list) # Any attachments sent. before_message: Message | None = None message_cache: MessageCache | None = None # Output context @@ -45,11 +49,12 @@ class FilterContext: notification_domain: str = "" # A domain to send the user for context filter_info: dict['Filter', str] = field(default_factory=dict) # Additional info from a filter. messages_deletion: bool = False # Whether the messages were deleted. Can't upload deletion log otherwise. + blocked_exts: set[str] = field(default_factory=set) # Any extensions blocked (used for snekbox) # Additional actions to perform additional_actions: list[Callable[[FilterContext], Coroutine]] = field(default_factory=list) related_messages: set[Message] = field(default_factory=set) # Deletion will include these. related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set) - attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs. + uploaded_attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs. upload_deletion_logs: bool = True # Whether it's allowed to upload deletion logs. @classmethod @@ -57,7 +62,17 @@ class FilterContext: cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None ) -> FilterContext: """Create a filtering context from the attributes of a message.""" - return cls(event, message.author, message.channel, message.content, message, message.embeds, before, cache) + return cls( + event, + message.author, + message.channel, + message.content, + message, + message.embeds, + message.attachments, + before, + cache + ) def replace(self, **changes) -> FilterContext: """Return a new context object assigning new values to the specified fields.""" diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py index 0e7ab2bdc..ba20051fc 100644 --- a/bot/exts/filtering/_filter_lists/antispam.py +++ b/bot/exts/filtering/_filter_lists/antispam.py @@ -171,7 +171,9 @@ class DeletionContext: new_ctx.related_channels = reduce( or_, (other_ctx.related_channels for other_ctx in other_contexts), ctx.related_channels ) | {ctx.channel for ctx in other_contexts} - new_ctx.attachments = reduce(or_, (other_ctx.attachments for other_ctx in other_contexts), ctx.attachments) + new_ctx.uploaded_attachments = reduce( + or_, (other_ctx.uploaded_attachments for other_ctx in other_contexts), ctx.uploaded_attachments + ) new_ctx.upload_deletion_logs = True new_ctx.messages_deletion = all(ctx.messages_deletion for ctx in self.contexts) diff --git a/bot/exts/filtering/_filter_lists/domain.py b/bot/exts/filtering/_filter_lists/domain.py index f4062edfe..091fd14e0 100644 --- a/bot/exts/filtering/_filter_lists/domain.py +++ b/bot/exts/filtering/_filter_lists/domain.py @@ -31,7 +31,7 @@ class DomainsList(FilterList[DomainFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) def get_filter_type(self, content: str) -> type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_filter_lists/extension.py b/bot/exts/filtering/_filter_lists/extension.py index a739d7191..868fde2b2 100644 --- a/bot/exts/filtering/_filter_lists/extension.py +++ b/bot/exts/filtering/_filter_lists/extension.py @@ -49,7 +49,7 @@ class ExtensionsList(FilterList[ExtensionFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE) + filtering_cog.subscribe(self, Event.MESSAGE, Event.SNEKBOX) self._whitelisted_description = None def get_filter_type(self, content: str) -> type[Filter]: @@ -66,7 +66,7 @@ class ExtensionsList(FilterList[ExtensionFilter]): ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" # Return early if the message doesn't have attachments. - if not ctx.message or not ctx.message.attachments: + if not ctx.message or not ctx.attachments: return None, [], {} _, failed = self[ListType.ALLOW].defaults.validations.evaluate(ctx) @@ -75,7 +75,7 @@ class ExtensionsList(FilterList[ExtensionFilter]): # Find all extensions in the message. all_ext = { - (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.message.attachments + (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.attachments } new_ctx = ctx.replace(content={ext for ext, _ in all_ext}) # And prepare the context for the filters to read. triggered = [ @@ -86,31 +86,37 @@ class ExtensionsList(FilterList[ExtensionFilter]): # See if there are any extensions left which aren't allowed. not_allowed = {ext: filename for ext, filename in all_ext if ext not in allowed_ext} + if ctx.event == Event.SNEKBOX: + not_allowed = {ext: filename for ext, filename in not_allowed.items() if ext not in TXT_LIKE_FILES} + if not not_allowed: # Yes, it's a double negative. Meaning all attachments are allowed :) return None, [], {ListType.ALLOW: triggered} - # Something is disallowed. - if ".py" in not_allowed: - # Provide a pastebin link for .py files. - ctx.dm_embed = PY_EMBED_DESCRIPTION - elif txt_extensions := {ext for ext in TXT_LIKE_FILES if ext in not_allowed}: - # Work around Discord auto-conversion of messages longer than 2000 chars to .txt - cmd_channel = bot.instance.get_channel(Channels.bot_commands) - ctx.dm_embed = TXT_EMBED_DESCRIPTION.format( - blocked_extension=txt_extensions.pop(), - cmd_channel_mention=cmd_channel.mention - ) - else: - meta_channel = bot.instance.get_channel(Channels.meta) - if not self._whitelisted_description: - self._whitelisted_description = ', '.join( - filter_.content for filter_ in self[ListType.ALLOW].filters.values() + # At this point, something is disallowed. + if ctx.event != Event.SNEKBOX: # Don't post the embed if it's a snekbox response. + if ".py" in not_allowed: + # Provide a pastebin link for .py files. + ctx.dm_embed = PY_EMBED_DESCRIPTION + elif txt_extensions := {ext for ext in TXT_LIKE_FILES if ext in not_allowed}: + # Work around Discord auto-conversion of messages longer than 2000 chars to .txt + cmd_channel = bot.instance.get_channel(Channels.bot_commands) + ctx.dm_embed = TXT_EMBED_DESCRIPTION.format( + blocked_extension=txt_extensions.pop(), + cmd_channel_mention=cmd_channel.mention + ) + else: + meta_channel = bot.instance.get_channel(Channels.meta) + if not self._whitelisted_description: + self._whitelisted_description = ', '.join( + filter_.content for filter_ in self[ListType.ALLOW].filters.values() + ) + ctx.dm_embed = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=self._whitelisted_description, + blocked_extensions_str=", ".join(not_allowed), + meta_channel_mention=meta_channel.mention, ) - ctx.dm_embed = DISALLOWED_EMBED_DESCRIPTION.format( - joined_whitelist=self._whitelisted_description, - blocked_extensions_str=", ".join(not_allowed), - meta_channel_mention=meta_channel.mention, - ) ctx.matches += not_allowed.values() - return self[ListType.ALLOW].defaults.actions, [f"`{ext}`" for ext in not_allowed], {ListType.ALLOW: triggered} + ctx.blocked_exts |= set(not_allowed) + actions = self[ListType.ALLOW].defaults.actions if ctx.event != Event.SNEKBOX else None + return actions, [f"`{ext}`" for ext in not_allowed], {ListType.ALLOW: triggered} diff --git a/bot/exts/filtering/_filter_lists/invite.py b/bot/exts/filtering/_filter_lists/invite.py index bd0eaa122..b9732a6dc 100644 --- a/bot/exts/filtering/_filter_lists/invite.py +++ b/bot/exts/filtering/_filter_lists/invite.py @@ -37,7 +37,7 @@ class InviteList(FilterList[InviteFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) def get_filter_type(self, content: str) -> type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py index f5da28bb5..0c591ac3b 100644 --- a/bot/exts/filtering/_filter_lists/token.py +++ b/bot/exts/filtering/_filter_lists/token.py @@ -32,7 +32,7 @@ class TokensList(FilterList[TokenFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.NICKNAME) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.NICKNAME, Event.SNEKBOX) def get_filter_type(self, content: str) -> type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_filters/unique/discord_token.py b/bot/exts/filtering/_filters/unique/discord_token.py index 6174ee30b..f4b9cc741 100644 --- a/bot/exts/filtering/_filters/unique/discord_token.py +++ b/bot/exts/filtering/_filters/unique/discord_token.py @@ -61,7 +61,7 @@ class DiscordTokenFilter(UniqueFilter): """Scans messages for potential discord client tokens and removes them.""" name = "discord_token" - events = (Event.MESSAGE, Event.MESSAGE_EDIT) + events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) extra_fields_type = ExtraDiscordTokenSettings @property diff --git a/bot/exts/filtering/_filters/unique/everyone.py b/bot/exts/filtering/_filters/unique/everyone.py index a32e67cc5..e49ede82f 100644 --- a/bot/exts/filtering/_filters/unique/everyone.py +++ b/bot/exts/filtering/_filters/unique/everyone.py @@ -16,7 +16,7 @@ class EveryoneFilter(UniqueFilter): """Filter messages which contain `@everyone` and `@here` tags outside a codeblock.""" name = "everyone" - events = (Event.MESSAGE, Event.MESSAGE_EDIT) + events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) async def triggered_on(self, ctx: FilterContext) -> bool: """Search for the filter's content within a given context.""" diff --git a/bot/exts/filtering/_filters/unique/webhook.py b/bot/exts/filtering/_filters/unique/webhook.py index 965ef42eb..4e1e2e44d 100644 --- a/bot/exts/filtering/_filters/unique/webhook.py +++ b/bot/exts/filtering/_filters/unique/webhook.py @@ -22,7 +22,7 @@ class WebhookFilter(UniqueFilter): """Scan messages to detect Discord webhooks links.""" name = "webhook" - events = (Event.MESSAGE, Event.MESSAGE_EDIT) + events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) @property def mod_log(self) -> ModLog | None: diff --git a/bot/exts/filtering/_settings_types/actions/remove_context.py b/bot/exts/filtering/_settings_types/actions/remove_context.py index 7ead88818..5ec2613f4 100644 --- a/bot/exts/filtering/_settings_types/actions/remove_context.py +++ b/bot/exts/filtering/_settings_types/actions/remove_context.py @@ -28,8 +28,8 @@ async def upload_messages_attachments(ctx: FilterContext, messages: list[Message return destination = messages[0].guild.get_channel(Channels.attachment_log) for message in messages: - if message.attachments and message.id not in ctx.attachments: - ctx.attachments[message.id] = await send_attachments(message, destination, link_large=False) + if message.attachments and message.id not in ctx.uploaded_attachments: + ctx.uploaded_attachments[message.id] = await send_attachments(message, destination, link_large=False) class RemoveContext(ActionEntry): diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 157906d6b..8cd2864a9 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -59,10 +59,10 @@ async def _build_alert_message_content(ctx: FilterContext, current_message_lengt # For multiple messages and those with attachments or excessive newlines, use the logs API if ctx.messages_deletion and ctx.upload_deletion_logs and any(( ctx.related_messages, - len(ctx.attachments) > 0, + len(ctx.uploaded_attachments) > 0, ctx.content.count('\n') > 15 )): - url = await upload_log(ctx.related_messages, bot.instance.user.id, ctx.attachments) + url = await upload_log(ctx.related_messages, bot.instance.user.id, ctx.uploaded_attachments) return f"A complete log of the offending messages can be found [here]({url})" alert_content = escape_markdown(ctx.content) @@ -70,7 +70,7 @@ async def _build_alert_message_content(ctx: FilterContext, current_message_lengt if len(alert_content) > remaining_chars: if ctx.messages_deletion and ctx.upload_deletion_logs: - url = await upload_log([ctx.message], bot.instance.user.id, ctx.attachments) + url = await upload_log([ctx.message], bot.instance.user.id, ctx.uploaded_attachments) log_site_msg = f"The full message can be found [here]({url})" # 7 because that's the length of "[...]\n\n" return alert_content[:remaining_chars - (7 + len(log_site_msg))] + "[...]\n\n" + log_site_msg diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index c4417e5e0..2a7f8f81f 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -40,6 +40,7 @@ from bot.exts.filtering._ui.ui import ( ) from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable from bot.exts.moderation.infraction.infractions import COMP_BAN_DURATION, COMP_BAN_REASON +from bot.exts.utils.snekbox._io import FileAttachment from bot.log import get_logger from bot.pagination import LinePaginator from bot.utils.channel import is_mod_channel @@ -251,24 +252,30 @@ class Filtering(Cog): ctx = FilterContext(Event.NICKNAME, member, None, member.display_name, None) await self._check_bad_name(ctx) - async def filter_snekbox_output(self, snekbox_result: str, msg: Message) -> bool: + async def filter_snekbox_output( + self, stdout: str, files: list[FileAttachment], msg: Message + ) -> tuple[bool, set[str]]: """ Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. Also requires the original message, to check whether to filter and for alerting. Any action (deletion, infraction) will be applied in the context of the original message. - Returns whether a filter was triggered or not. + Returns whether the output should be blocked, as well as a list of blocked file extensions. """ - ctx = FilterContext.from_message(Event.MESSAGE, msg).replace(content=snekbox_result) + content = stdout + if files: # Filter the filenames as well. + content += "\n\n" + "\n".join(file.filename for file in files) + ctx = FilterContext.from_message(Event.SNEKBOX, msg).replace(content=content, attachments=files) + result_actions, list_messages, triggers = await self._resolve_action(ctx) if result_actions: await result_actions.action(ctx) if ctx.send_alert: await self._send_alert(ctx, list_messages) - self._increment_stats(triggers) - return result_actions is not None + self._increment_stats(triggers) + return result_actions is not None, ctx.blocked_exts # endregion # region: blacklist commands diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py index 567fe6c24..d7e8bc93c 100644 --- a/bot/exts/utils/snekbox/_cog.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -14,9 +14,8 @@ from pydis_core.utils import interactions from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from bot.bot import Bot -from bot.constants import Channels, Emojis, MODERATION_ROLES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES, URLs +from bot.constants import Channels, Emojis, MODERATION_ROLES, Roles, URLs from bot.decorators import redirect_output -from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.filtering._filter_lists.extension import TXT_LIKE_FILES from bot.exts.help_channels._channel import is_help_forum_post from bot.exts.utils.snekbox._eval import EvalJob, EvalResult @@ -288,37 +287,22 @@ class Snekbox(Cog): return output, paste_link - def get_extensions_whitelist(self) -> set[str]: - """Return a set of whitelisted file extensions.""" - return set(self.bot.filter_list_cache['FILE_FORMAT.True'].keys()) | TXT_LIKE_FILES - - def _filter_files(self, ctx: Context, files: list[FileAttachment]) -> FilteredFiles: + def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles: """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" - # Check if user is staff, if is, return - # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance - if hasattr(ctx.author, "roles") and any(role.id in STAFF_PARTNERS_COMMUNITY_ROLES for role in ctx.author.roles): - return FilteredFiles(files, []) - # Ignore code jam channels - if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME: - return FilteredFiles(files, []) - - # Get whitelisted extensions - whitelist = self.get_extensions_whitelist() - # Filter files into allowed and blocked blocked = [] allowed = [] for file in files: - if file.suffix in whitelist: - allowed.append(file) - else: + if file.suffix in blocked_exts: blocked.append(file) + else: + allowed.append(file) if blocked: blocked_str = ", ".join(f.suffix for f in blocked) log.info( f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}", - extra={"attachment_list": [f.path for f in files]} + extra={"attachment_list": [f.filename for f in files]} ) return FilteredFiles(allowed, blocked) @@ -365,31 +349,8 @@ class Snekbox(Cog): else: self.bot.stats.incr("snekbox.python.success") - # Filter file extensions - allowed, blocked = self._filter_files(ctx, result.files) - # Also scan failed files for blocked extensions - failed_files = [FileAttachment(name, b"") for name in result.failed_files] - blocked.extend(self._filter_files(ctx, failed_files).blocked) - # Add notice if any files were blocked - if blocked: - blocked_sorted = sorted(set(f.suffix for f in blocked)) - # Only no extension - if len(blocked_sorted) == 1 and blocked_sorted[0] == "": - blocked_msg = "Files with no extension can't be uploaded." - # Both - elif "" in blocked_sorted: - blocked_str = ", ".join(ext for ext in blocked_sorted if ext) - blocked_msg = ( - f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" - ) - else: - blocked_str = ", ".join(blocked_sorted) - blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" - - msg += f"\n{Emojis.failed_file} {blocked_msg}" - # Split text files - text_files = [f for f in allowed if f.suffix in TXT_LIKE_FILES] + text_files = [f for f in result.files if f.suffix in TXT_LIKE_FILES] # Inline until budget, then upload to paste service # Budget is shared with stdout, so subtract what we've already used budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) @@ -417,8 +378,35 @@ class Snekbox(Cog): budget_chars -= len(file_text) filter_cog: Filtering | None = self.bot.get_cog("Filtering") - if filter_cog and (await filter_cog.filter_snekbox_output(msg, ctx.message)): - return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + blocked_exts = set() + # Include failed files in the scan. + failed_files = [FileAttachment(name, b"") for name in result.failed_files] + total_files = result.files + failed_files + if filter_cog: + block_output, blocked_exts = await filter_cog.filter_snekbox_output(msg, total_files, ctx.message) + if block_output: + return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + + # Filter file extensions + allowed, blocked = self._filter_files(ctx, result.files, blocked_exts) + blocked.extend(self._filter_files(ctx, failed_files, blocked_exts).blocked) + # Add notice if any files were blocked + if blocked: + blocked_sorted = sorted(set(f.suffix for f in blocked)) + # Only no extension + if len(blocked_sorted) == 1 and blocked_sorted[0] == "": + blocked_msg = "Files with no extension can't be uploaded." + # Both + elif "" in blocked_sorted: + blocked_str = ", ".join(ext for ext in blocked_sorted if ext) + blocked_msg = ( + f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + ) + else: + blocked_str = ", ".join(blocked_sorted) + blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + + msg += f"\n{Emojis.failed_file} {blocked_msg}" # Upload remaining non-text files files = [f.to_file() for f in allowed if f not in text_files] diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py index 9be396335..a45ecec1a 100644 --- a/bot/exts/utils/snekbox/_io.py +++ b/bot/exts/utils/snekbox/_io.py @@ -53,23 +53,23 @@ def normalize_discord_file_name(name: str) -> str: class FileAttachment: """File Attachment from Snekbox eval.""" - path: str + filename: str content: bytes def __repr__(self) -> str: """Return the content as a string.""" content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content - return f"FileAttachment(path={self.path!r}, content={content})" + return f"FileAttachment(path={self.filename!r}, content={content})" @property def suffix(self) -> str: """Return the file suffix.""" - return PurePosixPath(self.path).suffix + return PurePosixPath(self.filename).suffix @property def name(self) -> str: """Return the file name.""" - return PurePosixPath(self.path).name + return PurePosixPath(self.filename).name @classmethod def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: @@ -92,7 +92,7 @@ class FileAttachment: content = content.encode("utf-8") return { - "path": self.path, + "path": self.filename, "content": b64encode(content).decode("ascii"), } diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py index 0ad41116d..351daa0b4 100644 --- a/tests/bot/exts/filtering/test_extension_filter.py +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -45,9 +45,9 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - result = await self.filter_list.actions_for(self.ctx) + result = await self.filter_list.actions_for(ctx) self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) @@ -62,9 +62,9 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_illegal_extension(self): """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - result = await self.filter_list.actions_for(self.ctx) + result = await self.filter_list.actions_for(ctx) self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) @@ -72,11 +72,11 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_python_file_redirect_embed_description(self): """A message containing a .py file should result in an embed redirecting the user to our paste site.""" attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) - self.assertEqual(self.ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) @patch("bot.instance", BOT) async def test_txt_file_redirect_embed_description(self): @@ -91,12 +91,12 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) self.assertEqual( - self.ctx.dm_embed, + ctx.dm_embed, extension.TXT_EMBED_DESCRIPTION.format( blocked_extension=disallowed_extension, ) @@ -106,13 +106,13 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_other_disallowed_extension_embed_description(self): """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) meta_channel = BOT.get_channel(Channels.meta) self.assertEqual( - self.ctx.dm_embed, + ctx.dm_embed, extension.DISALLOWED_EMBED_DESCRIPTION.format( joined_whitelist=", ".join(self.whitelist), blocked_extensions_str=".disallowed", @@ -134,6 +134,6 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): for extensions, expected_disallowed_extensions in test_values: with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{ext}") for ext in extensions] - result = await self.filter_list.actions_for(self.ctx) + ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions]) + result = await self.filter_list.actions_for(ctx) self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9dcf7fd8c..79ac8ea2c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # Should not be called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code('MyAwesomeCode') @@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") -- cgit v1.2.3 From caf6bd4377e3f5a6426bc32df4bf711a897b62b1 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Sun, 26 Mar 2023 23:13:50 +0300 Subject: Rename additional_field to additional_settings --- bot/exts/filtering/_filters/filter.py | 4 +++- bot/exts/filtering/filtering.py | 4 ++-- tests/bot/exts/filtering/test_discord_token_filter.py | 2 +- tests/bot/exts/filtering/test_extension_filter.py | 2 +- tests/bot/exts/filtering/test_token_filter.py | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py index b5f4c127a..526d2fe67 100644 --- a/bot/exts/filtering/_filters/filter.py +++ b/bot/exts/filtering/_filters/filter.py @@ -31,7 +31,9 @@ class Filter(FieldRequiring): self.updated_at = arrow.get(filter_data["updated_at"]) self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults) if self.extra_fields_type: - self.extra_fields = self.extra_fields_type.parse_raw(filter_data["additional_field"] or "{}") # noqa: P103 + self.extra_fields = self.extra_fields_type.parse_raw( + filter_data["additional_settings"] or "{}" # noqa: P103 + ) else: self.extra_fields = None diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 2a7f8f81f..efea57a6a 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -1177,7 +1177,7 @@ class Filtering(Cog): description = description or None payload = { "filter_list": list_id, "content": content, "description": description, - "additional_field": json.dumps(filter_settings), **settings + "additional_settings": json.dumps(filter_settings), **settings } response = await bot.instance.api_client.post('bot/filter/filters', json=to_serializable(payload)) new_filter = filter_list.add_filter(list_type, response) @@ -1220,7 +1220,7 @@ class Filtering(Cog): description = description or None payload = { "filter_list": list_id, "content": content, "description": description, - "additional_field": json.dumps(filter_settings), **settings + "additional_settings": json.dumps(filter_settings), **settings } response = await bot.instance.api_client.patch( f'bot/filter/filters/{filter_.id}', json=to_serializable(payload) diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py index ef124e6ff..4d7e69bdc 100644 --- a/tests/bot/exts/filtering/test_discord_token_filter.py +++ b/tests/bot/exts/filtering/test_discord_token_filter.py @@ -22,7 +22,7 @@ class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase): "content": "discord_token", "description": None, "settings": {}, - "additional_field": "{}", # noqa: P103 + "additional_settings": "{}", # noqa: P103 "created_at": now, "updated_at": now }) diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py index 351daa0b4..52506d0be 100644 --- a/tests/bot/exts/filtering/test_extension_filter.py +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -25,7 +25,7 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): for i, filter_content in enumerate(self.whitelist, start=1): filters.append({ "id": i, "content": filter_content, "description": None, "settings": {}, - "additional_field": "{}", "created_at": now, "updated_at": now # noqa: P103 + "additional_settings": "{}", "created_at": now, "updated_at": now # noqa: P103 }) self.filter_list.add_list({ "id": 1, diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py index 0dfc8ae9f..82cc6b67e 100644 --- a/tests/bot/exts/filtering/test_token_filter.py +++ b/tests/bot/exts/filtering/test_token_filter.py @@ -40,7 +40,7 @@ class TokenFilterTests(unittest.IsolatedAsyncioTestCase): "content": pattern, "description": None, "settings": {}, - "additional_field": "{}", # noqa: P103 + "additional_settings": "{}", # noqa: P103 "created_at": now, "updated_at": now }) -- cgit v1.2.3 From f01883682f4d333382d8e8a89363dc906fe86342 Mon Sep 17 00:00:00 2001 From: mbaruh Date: Tue, 28 Mar 2023 01:12:25 +0300 Subject: Support custom value representation in filtering UI Adds the `CustomIOField` class which can be used as a base for wrappers that store a value with a customized way to process the user input and to present the value in the UI. --- .../actions/infraction_and_notification.py | 52 ++++++++++++++--- bot/exts/filtering/_ui/filter.py | 10 ++-- bot/exts/filtering/_ui/search.py | 4 +- bot/exts/filtering/_utils.py | 66 ++++++++++++++++++++-- bot/exts/filtering/filtering.py | 8 +-- tests/bot/exts/filtering/test_settings_entries.py | 16 +++--- 6 files changed, 127 insertions(+), 29 deletions(-) (limited to 'tests') diff --git a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py index 5ae4901b6..e3df47029 100644 --- a/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py +++ b/bot/exts/filtering/_settings_types/actions/infraction_and_notification.py @@ -1,9 +1,9 @@ -from datetime import timedelta from enum import Enum, auto from typing import ClassVar import arrow import discord.abc +from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Member, User from discord.errors import Forbidden from pydantic import validator @@ -15,7 +15,8 @@ import bot as bot_module from bot.constants import Channels from bot.exts.filtering._filter_context import FilterContext from bot.exts.filtering._settings_types.settings_entry import ActionEntry -from bot.exts.filtering._utils import FakeContext +from bot.exts.filtering._utils import CustomIOField, FakeContext +from bot.utils.time import humanize_delta, parse_duration_string, relativedelta_to_timedelta log = get_logger(__name__) @@ -31,6 +32,38 @@ passive_form = { } +class InfractionDuration(CustomIOField): + """A field that converts a string to a duration and presents it in a human-readable format.""" + + @classmethod + def process_value(cls, v: str | relativedelta) -> relativedelta: + """ + Transform the given string into a relativedelta. + + Raise a ValueError if the conversion is not possible. + """ + if isinstance(v, relativedelta): + return v + + try: + v = float(v) + except ValueError: # Not a float. + if not (delta := parse_duration_string(v)): + raise ValueError(f"`{v}` is not a valid duration string.") + else: + delta = relativedelta(seconds=float(v)).normalized() + + return delta + + def serialize(self) -> float: + """The serialized value is the total number of seconds this duration represents.""" + return relativedelta_to_timedelta(self.value).total_seconds() + + def __str__(self): + """Represent the stored duration in a human-readable format.""" + return humanize_delta(self.value, max_units=2) if self.value else "Permanent" + + class Infraction(Enum): """An enumeration of infraction types. The lower the value, the higher it is on the hierarchy.""" @@ -53,7 +86,7 @@ class Infraction(Enum): message: discord.Message, channel: discord.abc.GuildChannel | discord.DMChannel, alerts_channel: discord.TextChannel, - duration: float, + duration: InfractionDuration, reason: str ) -> None: """Invokes the command matching the infraction name.""" @@ -72,7 +105,7 @@ class Infraction(Enum): if self.name in ("KICK", "WARNING", "WATCH", "NOTE"): await command(ctx, user, reason=reason or None) else: - duration = arrow.utcnow() + timedelta(seconds=duration) if duration else None + duration = arrow.utcnow().datetime + duration.value if duration.value else None await command(ctx, user, duration, reason=reason or None) @@ -91,7 +124,10 @@ class InfractionAndNotification(ActionEntry): "the harsher one will be applied (by type or duration).\n\n" "Valid infraction types in order of harshness: " ) + ", ".join(infraction.name for infraction in Infraction), - "infraction_duration": "How long the infraction should last for in seconds. 0 for permanent.", + "infraction_duration": ( + "How long the infraction should last for in seconds. 0 for permanent. " + "Also supports durations as in an infraction invocation (such as `10d`)." + ), "infraction_reason": "The reason delivered with the infraction.", "infraction_channel": ( "The channel ID in which to invoke the infraction (and send the confirmation message). " @@ -106,7 +142,7 @@ class InfractionAndNotification(ActionEntry): dm_embed: str infraction_type: Infraction infraction_reason: str - infraction_duration: float + infraction_duration: InfractionDuration infraction_channel: int @validator("infraction_type", pre=True) @@ -184,8 +220,10 @@ class InfractionAndNotification(ActionEntry): result = other.copy() other = self else: + now = arrow.utcnow().datetime if self.infraction_duration is None or ( - other.infraction_duration is not None and self.infraction_duration > other.infraction_duration + other.infraction_duration is not None + and now + self.infraction_duration.value > now + other.infraction_duration.value ): result = self.copy() else: diff --git a/bot/exts/filtering/_ui/filter.py b/bot/exts/filtering/_ui/filter.py index 1ef25f17a..5b23b71e9 100644 --- a/bot/exts/filtering/_ui/filter.py +++ b/bot/exts/filtering/_ui/filter.py @@ -33,7 +33,7 @@ def build_filter_repr_dict( default_setting_values = {} for settings_group in filter_list[list_type].defaults: for _, setting in settings_group.items(): - default_setting_values.update(to_serializable(setting.dict())) + default_setting_values.update(to_serializable(setting.dict(), ui_repr=True)) # Add overrides. It's done in this way to preserve field order, since the filter won't have all settings. total_values = {} @@ -434,10 +434,10 @@ def description_and_settings_converter( return description, settings, filter_settings -def filter_serializable_overrides(filter_: Filter) -> tuple[dict, dict]: - """Get a serializable version of the filter's overrides.""" +def filter_overrides_for_ui(filter_: Filter) -> tuple[dict, dict]: + """Get the filter's overrides in a format that can be displayed in the UI.""" overrides_values, extra_fields_overrides = filter_.overrides - return to_serializable(overrides_values), to_serializable(extra_fields_overrides) + return to_serializable(overrides_values, ui_repr=True), to_serializable(extra_fields_overrides, ui_repr=True) def template_settings( @@ -461,4 +461,4 @@ def template_settings( raise BadArgument( f"The template filter name is {filter_.name!r}, but the target filter is {filter_type.name!r}" ) - return filter_serializable_overrides(filter_) + return filter_.overrides diff --git a/bot/exts/filtering/_ui/search.py b/bot/exts/filtering/_ui/search.py index d553c28ea..dba7f3cea 100644 --- a/bot/exts/filtering/_ui/search.py +++ b/bot/exts/filtering/_ui/search.py @@ -10,7 +10,7 @@ from discord.ext.commands import BadArgument from bot.exts.filtering._filter_lists import FilterList, ListType from bot.exts.filtering._filters.filter import Filter from bot.exts.filtering._settings_types.settings_entry import SettingsEntry -from bot.exts.filtering._ui.filter import filter_serializable_overrides +from bot.exts.filtering._ui.filter import filter_overrides_for_ui from bot.exts.filtering._ui.ui import ( COMPONENT_TIMEOUT, CustomCallbackSelect, EditBaseView, MISSING, SETTINGS_DELIMITER, parse_value, populate_embed_from_dict @@ -114,7 +114,7 @@ def template_settings( if filter_type and not isinstance(filter_, filter_type): raise BadArgument(f"The filter with ID `{filter_id}` is not of type {filter_type.name!r}.") - settings, filter_settings = filter_serializable_overrides(filter_) + settings, filter_settings = filter_overrides_for_ui(filter_) return settings, filter_settings, type(filter_) diff --git a/bot/exts/filtering/_utils.py b/bot/exts/filtering/_utils.py index da433330f..a43233f20 100644 --- a/bot/exts/filtering/_utils.py +++ b/bot/exts/filtering/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import importlib.util import inspect @@ -12,6 +14,7 @@ from typing import Any, Iterable, TypeVar, Union, get_args, get_origin import discord import regex from discord.ext.commands import Command +from typing_extensions import Self import bot from bot.bot import Bot @@ -24,6 +27,8 @@ ZALGO_RE = regex.compile(rf"[\p{{NONSPACING MARK}}\p{{ENCLOSING MARK}}--[{VARIAT T = TypeVar('T') +Serializable = Union[bool, int, float, str, list, dict, None] + def subclasses_in_package(package: str, prefix: str, parent: T) -> set[T]: """Return all the subclasses of class `parent`, found in the top-level of `package`, given by absolute path.""" @@ -62,8 +67,13 @@ def past_tense(word: str) -> str: return word + "ed" -def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None]: - """Convert the item into an object that can be converted to JSON.""" +def to_serializable(item: Any, *, ui_repr: bool = False) -> Serializable: + """ + Convert the item into an object that can be converted to JSON. + + `ui_repr` dictates whether to use the UI representation of `CustomIOField` instances (if any) + or the DB-oriented representation. + """ if isinstance(item, (bool, int, float, str, type(None))): return item if isinstance(item, dict): @@ -71,10 +81,12 @@ def to_serializable(item: Any) -> Union[bool, int, float, str, list, dict, None] for key, value in item.items(): if not isinstance(key, (bool, int, float, str, type(None))): key = str(key) - result[key] = to_serializable(value) + result[key] = to_serializable(value, ui_repr=ui_repr) return result if isinstance(item, Iterable): - return [to_serializable(subitem) for subitem in item] + return [to_serializable(subitem, ui_repr=ui_repr) for subitem in item] + if not ui_repr and hasattr(item, "serialize"): + return item.serialize() return str(item) @@ -222,3 +234,49 @@ class FakeContext: async def send(self, *args, **kwargs) -> discord.Message: """A wrapper for channel.send.""" return await self.channel.send(*args, **kwargs) + + +class CustomIOField: + """ + A class to be used as a data type in SettingEntry subclasses. + + Its subclasses can have custom methods to read and represent the value, which will be used by the UI. + """ + + def __init__(self, value: Any): + self.value = self.process_value(value) + + @classmethod + def __get_validators__(cls): + """Boilerplate for Pydantic.""" + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Self: + """Takes the given value and returns a class instance with that value.""" + if isinstance(v, CustomIOField): + return cls(v.value) + + return cls(v) + + def __eq__(self, other: CustomIOField): + if not isinstance(other, CustomIOField): + return NotImplemented + return self.value == other.value + + @classmethod + def process_value(cls, v: str) -> Any: + """ + Perform any necessary transformations before the value is stored in a new instance. + + Override this method to customize the input behavior. + """ + return v + + def serialize(self) -> Serializable: + """Override this method to customize how the value will be serialized.""" + return self.value + + def __str__(self): + """Override this method to change how the value will be displayed by the UI.""" + return self.value diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 58d2f125e..8fd4ddb13 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -31,7 +31,7 @@ from bot.exts.filtering._filters.filter import Filter, UniqueFilter from bot.exts.filtering._settings import ActionSettings from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction from bot.exts.filtering._ui.filter import ( - build_filter_repr_dict, description_and_settings_converter, filter_serializable_overrides, populate_embed_from_dict + build_filter_repr_dict, description_and_settings_converter, filter_overrides_for_ui, populate_embed_from_dict ) from bot.exts.filtering._ui.filter_list import FilterListAddView, FilterListEditView, settings_converter from bot.exts.filtering._ui.search import SearchEditView, search_criteria_converter @@ -383,7 +383,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result - overrides_values, extra_fields_overrides = filter_serializable_overrides(filter_) + overrides_values, extra_fields_overrides = filter_overrides_for_ui(filter_) all_settings_repr_dict = build_filter_repr_dict( filter_list, list_type, type(filter_), overrides_values, extra_fields_overrides @@ -493,7 +493,7 @@ class Filtering(Cog): return filter_, filter_list, list_type = result filter_type = type(filter_) - settings, filter_settings = filter_serializable_overrides(filter_) + settings, filter_settings = filter_overrides_for_ui(filter_) description, new_settings, new_filter_settings = description_and_settings_converter( filter_list, list_type, filter_type, @@ -734,7 +734,7 @@ class Filtering(Cog): setting_values = {} for settings_group in filter_list[list_type].defaults: for _, setting in settings_group.items(): - setting_values.update(to_serializable(setting.dict())) + setting_values.update(to_serializable(setting.dict(), ui_repr=True)) embed = Embed(colour=Colour.blue()) populate_embed_from_dict(embed, setting_values) diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py index c5f0152b0..3ae0b5ab5 100644 --- a/tests/bot/exts/filtering/test_settings_entries.py +++ b/tests/bot/exts/filtering/test_settings_entries.py @@ -1,7 +1,9 @@ import unittest from bot.exts.filtering._filter_context import Event, FilterContext -from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification +from bot.exts.filtering._settings_types.actions.infraction_and_notification import ( + Infraction, InfractionAndNotification, InfractionDuration +) from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM @@ -154,7 +156,7 @@ class FilterTests(unittest.TestCase): infraction1 = InfractionAndNotification( infraction_type="TIMEOUT", infraction_reason="hi", - infraction_duration=10, + infraction_duration=InfractionDuration(10), dm_content="how", dm_embed="what is", infraction_channel=0 @@ -162,7 +164,7 @@ class FilterTests(unittest.TestCase): infraction2 = InfractionAndNotification( infraction_type="TIMEOUT", infraction_reason="there", - infraction_duration=20, + infraction_duration=InfractionDuration(20), dm_content="are you", dm_embed="your name", infraction_channel=0 @@ -175,7 +177,7 @@ class FilterTests(unittest.TestCase): { "infraction_type": Infraction.TIMEOUT, "infraction_reason": "there", - "infraction_duration": 20.0, + "infraction_duration": InfractionDuration(20.0), "dm_content": "are you", "dm_embed": "your name", "infraction_channel": 0 @@ -187,7 +189,7 @@ class FilterTests(unittest.TestCase): infraction1 = InfractionAndNotification( infraction_type="TIMEOUT", infraction_reason="hi", - infraction_duration=20, + infraction_duration=InfractionDuration(20), dm_content="", dm_embed="", infraction_channel=0 @@ -195,7 +197,7 @@ class FilterTests(unittest.TestCase): infraction2 = InfractionAndNotification( infraction_type="BAN", infraction_reason="", - infraction_duration=10, + infraction_duration=InfractionDuration(10), dm_content="there", dm_embed="", infraction_channel=0 @@ -208,7 +210,7 @@ class FilterTests(unittest.TestCase): { "infraction_type": Infraction.BAN, "infraction_reason": "", - "infraction_duration": 10.0, + "infraction_duration": InfractionDuration(10), "dm_content": "there", "dm_embed": "", "infraction_channel": 0 -- cgit v1.2.3 From 7386e6b1d6ebaea727971fd25ff12840b0c7a435 Mon Sep 17 00:00:00 2001 From: Boris Muratov <8bee278@gmail.com> Date: Wed, 5 Apr 2023 03:15:17 +0300 Subject: Fix test --- tests/bot/exts/filtering/test_extension_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py index 52506d0be..800fad3a0 100644 --- a/tests/bot/exts/filtering/test_extension_filter.py +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -115,7 +115,7 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): ctx.dm_embed, extension.DISALLOWED_EMBED_DESCRIPTION.format( joined_whitelist=", ".join(self.whitelist), - blocked_extensions_str=".disallowed", + joined_blacklist=".disallowed", meta_channel_mention=meta_channel.mention ) ) -- cgit v1.2.3 From 076c2e910cdac0f168084caaf7d1331fc40a638a Mon Sep 17 00:00:00 2001 From: Boris Muratov <8bee278@gmail.com> Date: Thu, 6 Apr 2023 01:10:33 +0300 Subject: Adjust to site using dicts in additional_settings instead of JSON strings This is to make sure that the unique constraint always compares between dicts instead of sometimes dealing with nulls. --- bot/exts/filtering/_filters/filter.py | 4 +--- bot/exts/filtering/filtering.py | 4 ++-- tests/bot/exts/filtering/test_discord_token_filter.py | 2 +- tests/bot/exts/filtering/test_extension_filter.py | 2 +- tests/bot/exts/filtering/test_token_filter.py | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/bot/exts/filtering/_filters/filter.py b/bot/exts/filtering/_filters/filter.py index 128e84645..2b8f8d5d4 100644 --- a/bot/exts/filtering/_filters/filter.py +++ b/bot/exts/filtering/_filters/filter.py @@ -31,9 +31,7 @@ class Filter(FieldRequiring): self.updated_at = arrow.get(filter_data["updated_at"]) self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults) if self.extra_fields_type: - self.extra_fields = self.extra_fields_type.parse_raw( - filter_data["additional_settings"] or "{}" # noqa: P103 - ) + self.extra_fields = self.extra_fields_type.parse_obj(filter_data["additional_settings"]) else: self.extra_fields = None diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index 8fd4ddb13..392428bb0 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -1177,7 +1177,7 @@ class Filtering(Cog): description = description or None payload = { "filter_list": list_id, "content": content, "description": description, - "additional_settings": json.dumps(filter_settings), **settings + "additional_settings": filter_settings, **settings } response = await bot.instance.api_client.post('bot/filter/filters', json=to_serializable(payload)) new_filter = filter_list.add_filter(list_type, response) @@ -1220,7 +1220,7 @@ class Filtering(Cog): description = description or None payload = { "filter_list": list_id, "content": content, "description": description, - "additional_settings": json.dumps(filter_settings), **settings + "additional_settings": filter_settings, **settings } response = await bot.instance.api_client.patch( f'bot/filter/filters/{filter_.id}', json=to_serializable(payload) diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py index 4d7e69bdc..a5cddf8d9 100644 --- a/tests/bot/exts/filtering/test_discord_token_filter.py +++ b/tests/bot/exts/filtering/test_discord_token_filter.py @@ -22,7 +22,7 @@ class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase): "content": "discord_token", "description": None, "settings": {}, - "additional_settings": "{}", # noqa: P103 + "additional_settings": {}, "created_at": now, "updated_at": now }) diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py index 800fad3a0..827d267d2 100644 --- a/tests/bot/exts/filtering/test_extension_filter.py +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -25,7 +25,7 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): for i, filter_content in enumerate(self.whitelist, start=1): filters.append({ "id": i, "content": filter_content, "description": None, "settings": {}, - "additional_settings": "{}", "created_at": now, "updated_at": now # noqa: P103 + "additional_settings": {}, "created_at": now, "updated_at": now # noqa: P103 }) self.filter_list.add_list({ "id": 1, diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py index 82cc6b67e..03fa6b4b9 100644 --- a/tests/bot/exts/filtering/test_token_filter.py +++ b/tests/bot/exts/filtering/test_token_filter.py @@ -40,7 +40,7 @@ class TokenFilterTests(unittest.IsolatedAsyncioTestCase): "content": pattern, "description": None, "settings": {}, - "additional_settings": "{}", # noqa: P103 + "additional_settings": {}, "created_at": now, "updated_at": now }) -- cgit v1.2.3