aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2023-04-09 00:34:55 +0300
committerGravatar GitHub <[email protected]>2023-04-09 00:34:55 +0300
commitc74a3c645d23d22206182813478f6c9b74812ed8 (patch)
tree62b648f92c2607a5ed94f90a86b09698d16f8c98 /tests
parentFix: use nomination.user_id instead of id in get_or_fetch_member (diff)
parentMerge pull request #2517 from python-discord/thread_filter (diff)
Merge branch 'main' into 2302-activity-in-reviews
Diffstat (limited to 'tests')
-rw-r--r--tests/_autospec.py3
-rw-r--r--tests/bot/exts/filtering/__init__.py (renamed from tests/bot/exts/filters/__init__.py)0
-rw-r--r--tests/bot/exts/filtering/test_discord_token_filter.py276
-rw-r--r--tests/bot/exts/filtering/test_extension_filter.py139
-rw-r--r--tests/bot/exts/filtering/test_settings.py20
-rw-r--r--tests/bot/exts/filtering/test_settings_entries.py218
-rw-r--r--tests/bot/exts/filtering/test_token_filter.py49
-rw-r--r--tests/bot/exts/filters/test_antimalware.py202
-rw-r--r--tests/bot/exts/filters/test_antispam.py35
-rw-r--r--tests/bot/exts/filters/test_filtering.py40
-rw-r--r--tests/bot/exts/filters/test_token_remover.py409
-rw-r--r--tests/bot/exts/utils/snekbox/test_snekbox.py8
-rw-r--r--tests/bot/rules/__init__.py76
-rw-r--r--tests/bot/rules/test_attachments.py69
-rw-r--r--tests/bot/rules/test_burst.py54
-rw-r--r--tests/bot/rules/test_burst_shared.py57
-rw-r--r--tests/bot/rules/test_chars.py64
-rw-r--r--tests/bot/rules/test_discord_emojis.py73
-rw-r--r--tests/bot/rules/test_duplicates.py64
-rw-r--r--tests/bot/rules/test_links.py67
-rw-r--r--tests/bot/rules/test_mentions.py131
-rw-r--r--tests/bot/rules/test_newlines.py102
-rw-r--r--tests/bot/rules/test_role_mentions.py55
-rw-r--r--tests/helpers.py8
24 files changed, 712 insertions, 1507 deletions
diff --git a/tests/_autospec.py b/tests/_autospec.py
index ee2fc1973..ecff6bcbe 100644
--- a/tests/_autospec.py
+++ b/tests/_autospec.py
@@ -1,5 +1,6 @@
import contextlib
import functools
+import pkgutil
import unittest.mock
from typing import Callable
@@ -51,7 +52,7 @@ def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs)
# Import the target if it's a string.
# This is to support both object and string targets like patch.multiple.
if type(target) is str:
- target = unittest.mock._importer(target)
+ target = pkgutil.resolve_name(target)
def decorator(func):
for attribute in attributes:
diff --git a/tests/bot/exts/filters/__init__.py b/tests/bot/exts/filtering/__init__.py
index e69de29bb..e69de29bb 100644
--- a/tests/bot/exts/filters/__init__.py
+++ b/tests/bot/exts/filtering/__init__.py
diff --git a/tests/bot/exts/filtering/test_discord_token_filter.py b/tests/bot/exts/filtering/test_discord_token_filter.py
new file mode 100644
index 000000000..a5cddf8d9
--- /dev/null
+++ b/tests/bot/exts/filtering/test_discord_token_filter.py
@@ -0,0 +1,276 @@
+import unittest
+from re import Match
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import arrow
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.unique import discord_token
+from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter, Token
+from tests.helpers import MockBot, MockMember, MockMessage, MockTextChannel, autospec
+
+
+class DiscordTokenFilterTests(unittest.IsolatedAsyncioTestCase):
+ """Tests the DiscordTokenFilter class."""
+
+ def setUp(self):
+ """Adds the filter, a bot, and a message to the instance for usage in tests."""
+ now = arrow.utcnow().timestamp()
+ self.filter = DiscordTokenFilter({
+ "id": 1,
+ "content": "discord_token",
+ "description": None,
+ "settings": {},
+ "additional_settings": {},
+ "created_at": now,
+ "updated_at": now
+ })
+
+ self.msg = MockMessage(id=555, content="hello world")
+ self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
+
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.msg)
+
+ def test_extract_user_id_valid(self):
+ """Should consider user IDs valid if they decode into an integer ID."""
+ id_pairs = (
+ ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332),
+ ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904),
+ ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641),
+ )
+
+ for token_id, user_id in id_pairs:
+ with self.subTest(token_id=token_id):
+ result = DiscordTokenFilter.extract_user_id(token_id)
+ self.assertEqual(result, user_id)
+
+ def test_extract_user_id_invalid(self):
+ """Should consider non-digit and non-ASCII IDs invalid."""
+ ids = (
+ ("SGVsbG8gd29ybGQ", "non-digit ASCII"),
+ ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"),
+ ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"),
+ ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"),
+ ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
+
+ for user_id, msg in ids:
+ with self.subTest(msg=msg):
+ result = DiscordTokenFilter.extract_user_id(user_id)
+ self.assertIsNone(result)
+
+ def test_is_valid_timestamp_valid(self):
+ """Should consider timestamps valid if they're greater than the Discord epoch."""
+ timestamps = (
+ "XsyRkw",
+ "Xrim9Q",
+ "XsyR-w",
+ "XsySD_",
+ "Dn9r_A",
+ )
+
+ for timestamp in timestamps:
+ with self.subTest(timestamp=timestamp):
+ result = DiscordTokenFilter.is_valid_timestamp(timestamp)
+ self.assertTrue(result)
+
+ def test_is_valid_timestamp_invalid(self):
+ """Should consider timestamps invalid if they're before Discord epoch or can't be parsed."""
+ timestamps = (
+ ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"),
+ ("ew", "123"),
+ ("AoIKgA", "42076800"),
+ ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
+ ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
+ )
+
+ for timestamp, msg in timestamps:
+ with self.subTest(msg=msg):
+ result = DiscordTokenFilter.is_valid_timestamp(timestamp)
+ self.assertFalse(result)
+
+ def test_is_valid_hmac_valid(self):
+ """Should consider an HMAC valid if it has at least 3 unique characters."""
+ valid_hmacs = (
+ "VXmErH7j511turNpfURmb0rVNm8",
+ "Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "sJf6omBPORBPju3WJEIAcwW9Zds",
+ "s45jqDV_Iisn-symw0yDRrk_jf4",
+ )
+
+ for hmac in valid_hmacs:
+ with self.subTest(msg=hmac):
+ result = DiscordTokenFilter.is_maybe_valid_hmac(hmac)
+ self.assertTrue(result)
+
+ def test_is_invalid_hmac_invalid(self):
+ """Should consider an HMAC invalid if has fewer than 3 unique characters."""
+ invalid_hmacs = (
+ ("xxxxxxxxxxxxxxxxxx", "Single character"),
+ ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"),
+ ("ASFasfASFasfASFASsf", "Three characters alternating-case"),
+ ("asdasdasdasdasdasdasd", "Three characters one case"),
+ )
+
+ for hmac, msg in invalid_hmacs:
+ with self.subTest(msg=msg):
+ result = DiscordTokenFilter.is_maybe_valid_hmac(hmac)
+ self.assertFalse(result)
+
+ async def test_no_trigger_when_no_token(self):
+ """False should be returned if the message doesn't contain a Discord token."""
+ return_value = await self.filter.triggered_on(self.ctx)
+
+ self.assertFalse(return_value)
+
+ @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "Token")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE")
+ def test_find_token_valid_match(
+ self,
+ token_re,
+ token_cls,
+ extract_user_id,
+ is_valid_timestamp,
+ is_maybe_valid_hmac,
+ ):
+ """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`."""
+ matches = [
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ mock.create_autospec(Match, spec_set=True, instance=True),
+ ]
+ tokens = [
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ mock.create_autospec(Token, spec_set=True, instance=True),
+ ]
+
+ token_re.finditer.return_value = matches
+ token_cls.side_effect = tokens
+ extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid.
+ is_valid_timestamp.return_value = True
+ is_maybe_valid_hmac.return_value = True
+
+ return_value = DiscordTokenFilter.find_token_in_message(self.msg)
+
+ self.assertEqual(tokens[1], return_value)
+
+ @autospec(DiscordTokenFilter, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "Token")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "TOKEN_RE")
+ def test_find_token_invalid_matches(
+ self,
+ token_re,
+ token_cls,
+ extract_user_id,
+ is_valid_timestamp,
+ is_maybe_valid_hmac,
+ ):
+ """None should be returned if no matches have valid user IDs, HMACs, and timestamps."""
+ token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
+ token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
+ extract_user_id.return_value = None
+ is_valid_timestamp.return_value = False
+ is_maybe_valid_hmac.return_value = False
+
+ return_value = DiscordTokenFilter.find_token_in_message(self.msg)
+
+ self.assertIsNone(return_value)
+
+ def test_regex_invalid_tokens(self):
+ """Messages without anything looking like a token are not matched."""
+ tokens = (
+ "",
+ "lemon wins",
+ "..",
+ "x.y",
+ "x.y.",
+ ".y.z",
+ ".y.",
+ "..z",
+ "x..z",
+ " . . ",
+ "\n.\n.\n",
+ "hellö.world.bye",
+ "base64.nötbåse64.morebase64",
+ "19jd3J.dfkm3d.€víł§tüff",
+ )
+
+ for token in tokens:
+ with self.subTest(token=token):
+ results = discord_token.TOKEN_RE.findall(token)
+ self.assertEqual(len(results), 0)
+
+ def test_regex_valid_tokens(self):
+ """Messages that look like tokens should be matched."""
+ # Don't worry, these tokens have been invalidated.
+ tokens = (
+ "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8",
+ "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8",
+ "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds",
+ "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",
+ )
+
+ for token in tokens:
+ with self.subTest(token=token):
+ results = discord_token.TOKEN_RE.fullmatch(token)
+ self.assertIsNotNone(results, f"{token} was not matched by the regex")
+
+ def test_regex_matches_multiple_valid(self):
+ """Should support multiple matches in the middle of a string."""
+ token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8"
+ token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc"
+ message = f"garbage {token_1} hello {token_2} world"
+
+ results = discord_token.TOKEN_RE.finditer(message)
+ results = [match[0] for match in results]
+ self.assertCountEqual((token_1, token_2), results)
+
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "LOG_MESSAGE")
+ def test_format_log_message(self, log_message):
+ """Should correctly format the log message with info from the message and token."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ log_message.format.return_value = "Howdy"
+
+ return_value = DiscordTokenFilter.format_log_message(self.msg.author, self.msg.channel, token)
+
+ self.assertEqual(return_value, log_message.format.return_value)
+
+ @patch("bot.instance", MockBot())
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "UNKNOWN_USER_LOG_MESSAGE")
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "get_or_fetch_member")
+ async def test_format_userid_log_message_unknown(self, get_or_fetch_member, unknown_user_log_message):
+ """Should correctly format the user ID portion when the actual user it belongs to is unknown."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ unknown_user_log_message.format.return_value = " Partner"
+ get_or_fetch_member.return_value = None
+
+ return_value = await DiscordTokenFilter.format_userid_log_message(token)
+
+ self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False))
+
+ @patch("bot.instance", MockBot())
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE")
+ async def test_format_userid_log_message_bot(self, known_user_log_message):
+ """Should correctly format the user ID portion when the ID belongs to a known bot."""
+ token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ known_user_log_message.format.return_value = " Partner"
+
+ return_value = await DiscordTokenFilter.format_userid_log_message(token)
+
+ self.assertEqual(return_value, (known_user_log_message.format.return_value, True))
+
+ @patch("bot.instance", MockBot())
+ @autospec("bot.exts.filtering._filters.unique.discord_token", "KNOWN_USER_LOG_MESSAGE")
+ async def test_format_log_message_user_token_user(self, user_token_message):
+ """Should correctly format the user ID portion when the ID belongs to a known user."""
+ token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
+ user_token_message.format.return_value = "Partner"
+
+ return_value = await DiscordTokenFilter.format_userid_log_message(token)
+
+ self.assertEqual(return_value, (user_token_message.format.return_value, True))
diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py
new file mode 100644
index 000000000..827d267d2
--- /dev/null
+++ b/tests/bot/exts/filtering/test_extension_filter.py
@@ -0,0 +1,139 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import arrow
+
+from bot.constants import Channels
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filter_lists import extension
+from bot.exts.filtering._filter_lists.extension import ExtensionsList
+from bot.exts.filtering._filter_lists.filter_list import ListType
+from tests.helpers import MockAttachment, MockBot, MockMember, MockMessage, MockTextChannel
+
+BOT = MockBot()
+
+
+class ExtensionsListTests(unittest.IsolatedAsyncioTestCase):
+ """Test the ExtensionsList class."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.filter_list = ExtensionsList(MagicMock())
+ now = arrow.utcnow().timestamp()
+ filters = []
+ self.whitelist = [".first", ".second", ".third"]
+ for i, filter_content in enumerate(self.whitelist, start=1):
+ filters.append({
+ "id": i, "content": filter_content, "description": None, "settings": {},
+ "additional_settings": {}, "created_at": now, "updated_at": now # noqa: P103
+ })
+ self.filter_list.add_list({
+ "id": 1,
+ "list_type": 1,
+ "created_at": now,
+ "updated_at": now,
+ "settings": {},
+ "filters": filters
+ })
+
+ self.message = MockMessage()
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", self.message)
+
+ @patch("bot.instance", BOT)
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should trigger the whitelist and result in no actions or messages."""
+ attachment = MockAttachment(filename="python.first")
+ ctx = self.ctx.replace(attachments=[attachment])
+
+ result = await self.filter_list.actions_for(ctx)
+
+ self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]}))
+
+ @patch("bot.instance", BOT)
+ async def test_message_without_attachment(self):
+ """Messages without attachments should return no triggers, messages, or actions."""
+ result = await self.filter_list.actions_for(self.ctx)
+
+ self.assertEqual(result, (None, [], {}))
+
+ @patch("bot.instance", BOT)
+ async def test_message_with_illegal_extension(self):
+ """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message."""
+ attachment = MockAttachment(filename="python.disallowed")
+ ctx = self.ctx.replace(attachments=[attachment])
+
+ result = await self.filter_list.actions_for(ctx)
+
+ self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []}))
+
+ @patch("bot.instance", BOT)
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site."""
+ attachment = MockAttachment(filename="python.py")
+ ctx = self.ctx.replace(attachments=[attachment])
+
+ await self.filter_list.actions_for(ctx)
+
+ self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION)
+
+ @patch("bot.instance", BOT)
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt/.json/.csv file should result in the correct embed."""
+ test_values = (
+ ("text", ".txt"),
+ ("json", ".json"),
+ ("csv", ".csv"),
+ )
+
+ for file_name, disallowed_extension in test_values:
+ with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension):
+
+ attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}")
+ ctx = self.ctx.replace(attachments=[attachment])
+
+ await self.filter_list.actions_for(ctx)
+
+ self.assertEqual(
+ ctx.dm_embed,
+ extension.TXT_EMBED_DESCRIPTION.format(
+ blocked_extension=disallowed_extension,
+ )
+ )
+
+ @patch("bot.instance", BOT)
+ async def test_other_disallowed_extension_embed_description(self):
+ """Test the description for a non .py/.txt/.json/.csv disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ ctx = self.ctx.replace(attachments=[attachment])
+
+ await self.filter_list.actions_for(ctx)
+ meta_channel = BOT.get_channel(Channels.meta)
+
+ self.assertEqual(
+ ctx.dm_embed,
+ extension.DISALLOWED_EMBED_DESCRIPTION.format(
+ joined_whitelist=", ".join(self.whitelist),
+ joined_blacklist=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+ )
+
+ @patch("bot.instance", BOT)
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (self.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], ["`.disallowed`"]),
+ ([".disallowed"], ["`.disallowed`"]),
+ ([".disallowed", ".illegal"], ["`.disallowed`", "`.illegal`"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions])
+ result = await self.filter_list.actions_for(ctx)
+ self.assertCountEqual(result[1], expected_disallowed_extensions)
diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py
new file mode 100644
index 000000000..5a289c1cf
--- /dev/null
+++ b/tests/bot/exts/filtering/test_settings.py
@@ -0,0 +1,20 @@
+import unittest
+
+import bot.exts.filtering._settings
+from bot.exts.filtering._settings import create_settings
+
+
+class FilterTests(unittest.TestCase):
+ """Test functionality of the Settings class and its subclasses."""
+
+ def test_create_settings_returns_none_for_empty_data(self):
+ """`create_settings` should return a tuple of two Nones when passed an empty dict."""
+ result = create_settings({})
+
+ self.assertEqual(result, (None, None))
+
+ def test_unrecognized_entry_makes_a_warning(self):
+ """When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`."""
+ create_settings({"abcd": {}})
+
+ self.assertIn("abcd", bot.exts.filtering._settings._already_warned)
diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py
new file mode 100644
index 000000000..3ae0b5ab5
--- /dev/null
+++ b/tests/bot/exts/filtering/test_settings_entries.py
@@ -0,0 +1,218 @@
+import unittest
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._settings_types.actions.infraction_and_notification import (
+ Infraction, InfractionAndNotification, InfractionDuration
+)
+from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass
+from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope
+from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM
+from tests.helpers import MockCategoryChannel, MockDMChannel, MockMember, MockMessage, MockRole, MockTextChannel
+
+
+class FilterTests(unittest.TestCase):
+ """Test functionality of the Settings class and its subclasses."""
+
+ def setUp(self) -> None:
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ message = MockMessage(author=member, channel=channel)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message)
+
+ def test_role_bypass_is_off_for_user_without_roles(self):
+ """The role bypass should trigger when a user has no roles."""
+ member = MockMember()
+ self.ctx.author = member
+ bypass_entry = RoleBypass(bypass_roles=["123"])
+
+ result = bypass_entry.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_role_bypass_is_on_for_a_user_with_the_right_role(self):
+ """The role bypass should not trigger when the user has one of its roles."""
+ cases = (
+ ([123], ["123"]),
+ ([123, 234], ["123"]),
+ ([123], ["123", "234"]),
+ ([123, 234], ["123", "234"])
+ )
+
+ for user_role_ids, bypasses in cases:
+ with self.subTest(user_role_ids=user_role_ids, bypasses=bypasses):
+ user_roles = [MockRole(id=role_id) for role_id in user_role_ids]
+ member = MockMember(roles=user_roles)
+ self.ctx.author = member
+ bypass_entry = RoleBypass(bypass_roles=bypasses)
+
+ result = bypass_entry.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_doesnt_trigger_for_empty_channel_scope(self):
+ """A filter is enabled for all channels by default."""
+ channel = MockTextChannel()
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_context_doesnt_trigger_for_disabled_channel(self):
+ """A filter shouldn't trigger if it's been disabled in the channel."""
+ channel = MockTextChannel(id=123)
+ scope = ChannelScope(
+ disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_doesnt_trigger_in_disabled_category(self):
+ """A filter shouldn't trigger if it's been disabled in the category."""
+ channel = MockTextChannel(category=MockCategoryChannel(id=456))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=["456"], enabled_channels=None, enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_triggers_in_enabled_channel_in_disabled_category(self):
+ """A filter should trigger in an enabled channel even if it's been disabled in the category."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"], enabled_categories=None
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_context_triggers_inside_enabled_category(self):
+ """A filter shouldn't trigger outside enabled categories, if there are any."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["234"]
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertTrue(result)
+
+ def test_context_doesnt_trigger_outside_enabled_category(self):
+ """A filter shouldn't trigger outside enabled categories, if there are any."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=None, disabled_categories=None, enabled_channels=None, enabled_categories=["789"]
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_context_doesnt_trigger_inside_disabled_channel_in_enabled_category(self):
+ """A filter shouldn't trigger outside enabled categories, if there are any."""
+ channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
+ scope = ChannelScope(
+ disabled_channels=["123"], disabled_categories=None, enabled_channels=None, enabled_categories=["234"]
+ )
+ self.ctx.channel = channel
+
+ result = scope.triggers_on(self.ctx)
+
+ self.assertFalse(result)
+
+ def test_filtering_dms_when_necessary(self):
+ """A filter correctly ignores or triggers in a channel depending on the value of FilterDM."""
+ cases = (
+ (True, MockDMChannel(), True),
+ (False, MockDMChannel(), False),
+ (True, MockTextChannel(), True),
+ (False, MockTextChannel(), True)
+ )
+
+ for apply_in_dms, channel, expected in cases:
+ with self.subTest(apply_in_dms=apply_in_dms, channel=channel):
+ filter_dms = FilterDM(filter_dm=apply_in_dms)
+ self.ctx.channel = channel
+
+ result = filter_dms.triggers_on(self.ctx)
+
+ self.assertEqual(expected, result)
+
+ def test_infraction_merge_of_same_infraction_type(self):
+ """When both infractions are of the same type, the one with the longer duration wins."""
+ infraction1 = InfractionAndNotification(
+ infraction_type="TIMEOUT",
+ infraction_reason="hi",
+ infraction_duration=InfractionDuration(10),
+ dm_content="how",
+ dm_embed="what is",
+ infraction_channel=0
+ )
+ infraction2 = InfractionAndNotification(
+ infraction_type="TIMEOUT",
+ infraction_reason="there",
+ infraction_duration=InfractionDuration(20),
+ dm_content="are you",
+ dm_embed="your name",
+ infraction_channel=0
+ )
+
+ result = infraction1.union(infraction2)
+
+ self.assertDictEqual(
+ result.dict(),
+ {
+ "infraction_type": Infraction.TIMEOUT,
+ "infraction_reason": "there",
+ "infraction_duration": InfractionDuration(20.0),
+ "dm_content": "are you",
+ "dm_embed": "your name",
+ "infraction_channel": 0
+ }
+ )
+
+ def test_infraction_merge_of_different_infraction_types(self):
+ """If there are two different infraction types, the one higher up the hierarchy should be picked."""
+ infraction1 = InfractionAndNotification(
+ infraction_type="TIMEOUT",
+ infraction_reason="hi",
+ infraction_duration=InfractionDuration(20),
+ dm_content="",
+ dm_embed="",
+ infraction_channel=0
+ )
+ infraction2 = InfractionAndNotification(
+ infraction_type="BAN",
+ infraction_reason="",
+ infraction_duration=InfractionDuration(10),
+ dm_content="there",
+ dm_embed="",
+ infraction_channel=0
+ )
+
+ result = infraction1.union(infraction2)
+
+ self.assertDictEqual(
+ result.dict(),
+ {
+ "infraction_type": Infraction.BAN,
+ "infraction_reason": "",
+ "infraction_duration": InfractionDuration(10),
+ "dm_content": "there",
+ "dm_embed": "",
+ "infraction_channel": 0
+ }
+ )
diff --git a/tests/bot/exts/filtering/test_token_filter.py b/tests/bot/exts/filtering/test_token_filter.py
new file mode 100644
index 000000000..03fa6b4b9
--- /dev/null
+++ b/tests/bot/exts/filtering/test_token_filter.py
@@ -0,0 +1,49 @@
+import unittest
+
+import arrow
+
+from bot.exts.filtering._filter_context import Event, FilterContext
+from bot.exts.filtering._filters.token import TokenFilter
+from tests.helpers import MockMember, MockMessage, MockTextChannel
+
+
+class TokenFilterTests(unittest.IsolatedAsyncioTestCase):
+ """Test functionality of the token filter."""
+
+ def setUp(self) -> None:
+ member = MockMember(id=123)
+ channel = MockTextChannel(id=345)
+ message = MockMessage(author=member, channel=channel)
+ self.ctx = FilterContext(Event.MESSAGE, member, channel, "", message)
+
+ async def test_token_filter_triggers(self):
+ """The filter should evaluate to True only if its token is found in the context content."""
+ test_cases = (
+ (r"hi", "oh hi there", True),
+ (r"hi", "goodbye", False),
+ (r"bla\d{2,4}", "bla18", True),
+ (r"bla\d{2,4}", "bla1", False),
+ # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6
+ (r"TOKEN", "https://google.com TOKEN", True),
+ (r"TOKEN", "https://google.com something else", False)
+ )
+ now = arrow.utcnow().timestamp()
+
+ for pattern, content, expected in test_cases:
+ with self.subTest(
+ pattern=pattern,
+ content=content,
+ expected=expected,
+ ):
+ filter_ = TokenFilter({
+ "id": 1,
+ "content": pattern,
+ "description": None,
+ "settings": {},
+ "additional_settings": {},
+ "created_at": now,
+ "updated_at": now
+ })
+ self.ctx.content = content
+ result = await filter_.triggered_on(self.ctx)
+ self.assertEqual(result, expected)
diff --git a/tests/bot/exts/filters/test_antimalware.py b/tests/bot/exts/filters/test_antimalware.py
deleted file mode 100644
index 7282334e2..000000000
--- a/tests/bot/exts/filters/test_antimalware.py
+++ /dev/null
@@ -1,202 +0,0 @@
-import unittest
-from unittest.mock import AsyncMock, Mock
-
-from discord import NotFound
-
-from bot.constants import Channels, STAFF_ROLES
-from bot.exts.filters import antimalware
-from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
-
-
-class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
- """Test the AntiMalware cog."""
-
- def setUp(self):
- """Sets up fresh objects for each test."""
- self.bot = MockBot()
- self.bot.filter_list_cache = {
- "FILE_FORMAT.True": {
- ".first": {},
- ".second": {},
- ".third": {},
- }
- }
- self.cog = antimalware.AntiMalware(self.bot)
- self.message = MockMessage()
- self.message.webhook_id = None
- self.message.author.bot = None
- self.whitelist = [".first", ".second", ".third"]
-
- async def test_message_with_allowed_attachment(self):
- """Messages with allowed extensions should not be deleted"""
- attachment = MockAttachment(filename="python.first")
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
- self.message.delete.assert_not_called()
-
- async def test_message_without_attachment(self):
- """Messages without attachments should result in no action."""
- await self.cog.on_message(self.message)
- self.message.delete.assert_not_called()
-
- async def test_direct_message_with_attachment(self):
- """Direct messages should have no action taken."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
- self.message.guild = None
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_webhook_message_with_illegal_extension(self):
- """A webhook message containing an illegal extension should be ignored."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.webhook_id = 697140105563078727
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_bot_message_with_illegal_extension(self):
- """A bot message containing an illegal extension should be ignored."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.author.bot = 409107086526644234
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_message_with_illegal_extension_gets_deleted(self):
- """A message containing an illegal extension should send an embed."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_called_once()
-
- async def test_message_send_by_staff(self):
- """A message send by a member of staff should be ignored."""
- staff_role = MockRole(id=STAFF_ROLES[0])
- self.message.author.roles.append(staff_role)
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
-
- await self.cog.on_message(self.message)
-
- self.message.delete.assert_not_called()
-
- async def test_python_file_redirect_embed_description(self):
- """A message containing a .py file should result in an embed redirecting the user to our paste site"""
- attachment = MockAttachment(filename="python.py")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
-
- self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
-
- async def test_txt_file_redirect_embed_description(self):
- """A message containing a .txt/.json/.csv file should result in the correct embed."""
- test_values = (
- ("text", ".txt"),
- ("json", ".json"),
- ("csv", ".csv"),
- )
-
- for file_name, disallowed_extension in test_values:
- with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension):
-
- attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
- antimalware.TXT_EMBED_DESCRIPTION = Mock()
- antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
- cmd_channel = self.bot.get_channel(Channels.bot_commands)
-
- self.assertEqual(
- embed.description,
- antimalware.TXT_EMBED_DESCRIPTION.format.return_value
- )
- antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(
- blocked_extension=disallowed_extension,
- cmd_channel_mention=cmd_channel.mention
- )
-
- async def test_other_disallowed_extension_embed_description(self):
- """Test the description for a non .py/.txt/.json/.csv disallowed extension."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
- self.message.channel.send = AsyncMock()
- antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
- antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
-
- await self.cog.on_message(self.message)
- self.message.channel.send.assert_called_once()
- args, kwargs = self.message.channel.send.call_args
- embed = kwargs.pop("embed")
- meta_channel = self.bot.get_channel(Channels.meta)
-
- self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
- antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
- joined_whitelist=", ".join(self.whitelist),
- blocked_extensions_str=".disallowed",
- meta_channel_mention=meta_channel.mention
- )
-
- async def test_removing_deleted_message_logs(self):
- """Removing an already deleted message logs the correct message"""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
- self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
-
- with self.assertLogs(logger=antimalware.log, level="INFO"):
- await self.cog.on_message(self.message)
- self.message.delete.assert_called_once()
-
- async def test_message_with_illegal_attachment_logs(self):
- """Deleting a message with an illegal attachment should result in a log."""
- attachment = MockAttachment(filename="python.disallowed")
- self.message.attachments = [attachment]
-
- with self.assertLogs(logger=antimalware.log, level="INFO"):
- await self.cog.on_message(self.message)
-
- async def test_get_disallowed_extensions(self):
- """The return value should include all non-whitelisted extensions."""
- test_values = (
- ([], []),
- (self.whitelist, []),
- ([".first"], []),
- ([".first", ".disallowed"], [".disallowed"]),
- ([".disallowed"], [".disallowed"]),
- ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
- )
-
- for extensions, expected_disallowed_extensions in test_values:
- with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
- self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
- disallowed_extensions = self.cog._get_disallowed_extensions(self.message)
- self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
-
-
-class AntiMalwareSetupTests(unittest.IsolatedAsyncioTestCase):
- """Tests setup of the `AntiMalware` cog."""
-
- async def test_setup(self):
- """Setup of the extension should call add_cog."""
- bot = MockBot()
- await antimalware.setup(bot)
- bot.add_cog.assert_awaited_once()
diff --git a/tests/bot/exts/filters/test_antispam.py b/tests/bot/exts/filters/test_antispam.py
deleted file mode 100644
index 6a0e4fded..000000000
--- a/tests/bot/exts/filters/test_antispam.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import unittest
-
-from bot.exts.filters import antispam
-
-
-class AntispamConfigurationValidationTests(unittest.TestCase):
- """Tests validation of the antispam cog configuration."""
-
- def test_default_antispam_config_is_valid(self):
- """The default antispam configuration is valid."""
- validation_errors = antispam.validate_config()
- self.assertEqual(validation_errors, {})
-
- def test_unknown_rule_returns_error(self):
- """Configuring an unknown rule returns an error."""
- self.assertEqual(
- antispam.validate_config({'invalid-rule': {}}),
- {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."}
- )
-
- def test_missing_keys_returns_error(self):
- """Not configuring required keys returns an error."""
- keys = (('interval', 'max'), ('max', 'interval'))
- for configured_key, unconfigured_key in keys:
- with self.subTest(
- configured_key=configured_key,
- unconfigured_key=unconfigured_key
- ):
- config = {'burst': {configured_key: 10}}
- error = f"Key `{unconfigured_key}` is required but not set for rule `burst`"
-
- self.assertEqual(
- antispam.validate_config(config),
- {'burst': error}
- )
diff --git a/tests/bot/exts/filters/test_filtering.py b/tests/bot/exts/filters/test_filtering.py
deleted file mode 100644
index e47cf627b..000000000
--- a/tests/bot/exts/filters/test_filtering.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import unittest
-from unittest.mock import patch
-
-from bot.exts.filters import filtering
-from tests.helpers import MockBot, autospec
-
-
-class FilteringCogTests(unittest.IsolatedAsyncioTestCase):
- """Tests the `Filtering` cog."""
-
- def setUp(self):
- """Instantiate the bot and cog."""
- self.bot = MockBot()
- with patch("pydis_core.utils.scheduling.create_task", new=lambda task, **_: task.close()):
- self.cog = filtering.Filtering(self.bot)
-
- @autospec(filtering.Filtering, "_get_filterlist_items", pass_mocks=False, return_value=["TOKEN"])
- async def test_token_filter(self):
- """Ensure that a filter token is correctly detected in a message."""
- messages = {
- "": False,
- "no matches": False,
- "TOKEN": True,
-
- # See advisory https://github.com/python-discord/bot/security/advisories/GHSA-j8c3-8x46-8pp6
- "https://google.com TOKEN": True,
- "https://google.com something else": False,
- }
-
- for message, match in messages.items():
- with self.subTest(input=message, match=match):
- result, _ = await self.cog._has_watch_regex_match(message)
-
- self.assertEqual(
- match,
- bool(result),
- msg=f"Hit was {'expected' if match else 'not expected'} for this input."
- )
- if result:
- self.assertEqual("TOKEN", result.group())
diff --git a/tests/bot/exts/filters/test_token_remover.py b/tests/bot/exts/filters/test_token_remover.py
deleted file mode 100644
index c1f3762ac..000000000
--- a/tests/bot/exts/filters/test_token_remover.py
+++ /dev/null
@@ -1,409 +0,0 @@
-import unittest
-from re import Match
-from unittest import mock
-from unittest.mock import MagicMock
-
-from discord import Colour, NotFound
-
-from bot import constants
-from bot.exts.filters import token_remover
-from bot.exts.filters.token_remover import Token, TokenRemover
-from bot.exts.moderation.modlog import ModLog
-from bot.utils.messages import format_user
-from tests.helpers import MockBot, MockMessage, autospec
-
-
-class TokenRemoverTests(unittest.IsolatedAsyncioTestCase):
- """Tests the `TokenRemover` cog."""
-
- def setUp(self):
- """Adds the cog, a bot, and a message to the instance for usage in tests."""
- self.bot = MockBot()
- self.cog = TokenRemover(bot=self.bot)
-
- self.msg = MockMessage(id=555, content="hello world")
- self.msg.channel.mention = "#lemonade-stand"
- self.msg.guild.get_member.return_value.bot = False
- self.msg.guild.get_member.return_value.__str__.return_value = "Woody"
- self.msg.author.__str__ = MagicMock(return_value=self.msg.author.name)
- self.msg.author.display_avatar.url = "picture-lemon.png"
-
- def test_extract_user_id_valid(self):
- """Should consider user IDs valid if they decode into an integer ID."""
- id_pairs = (
- ("NDcyMjY1OTQzMDYyNDEzMzMy", 472265943062413332),
- ("NDc1MDczNjI5Mzk5NTQ3OTA0", 475073629399547904),
- ("NDY3MjIzMjMwNjUwNzc3NjQx", 467223230650777641),
- )
-
- for token_id, user_id in id_pairs:
- with self.subTest(token_id=token_id):
- result = TokenRemover.extract_user_id(token_id)
- self.assertEqual(result, user_id)
-
- def test_extract_user_id_invalid(self):
- """Should consider non-digit and non-ASCII IDs invalid."""
- ids = (
- ("SGVsbG8gd29ybGQ", "non-digit ASCII"),
- ("0J_RgNC40LLQtdGCINC80LjRgA", "cyrillic text"),
- ("4pO14p6L4p6C4pG34p264pGl8J-EiOKSj-KCieKBsA", "Unicode digits"),
- ("4oaA4oaB4oWh4oWi4Lyz4Lyq4Lyr4LG9", "Unicode numerals"),
- ("8J2fjvCdn5nwnZ-k8J2fr_Cdn7rgravvvJngr6c", "Unicode decimals"),
- ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
- ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
- )
-
- for user_id, msg in ids:
- with self.subTest(msg=msg):
- result = TokenRemover.extract_user_id(user_id)
- self.assertIsNone(result)
-
- def test_is_valid_timestamp_valid(self):
- """Should consider timestamps valid if they're greater than the Discord epoch."""
- timestamps = (
- "XsyRkw",
- "Xrim9Q",
- "XsyR-w",
- "XsySD_",
- "Dn9r_A",
- )
-
- for timestamp in timestamps:
- with self.subTest(timestamp=timestamp):
- result = TokenRemover.is_valid_timestamp(timestamp)
- self.assertTrue(result)
-
- def test_is_valid_timestamp_invalid(self):
- """Should consider timestamps invalid if they're before Discord epoch or can't be parsed."""
- timestamps = (
- ("B4Yffw", "DISCORD_EPOCH - TOKEN_EPOCH - 1"),
- ("ew", "123"),
- ("AoIKgA", "42076800"),
- ("{hello}[world]&(bye!)", "ASCII invalid Base64"),
- ("Þíß-ï§-ňøẗ-våłìÐ", "Unicode invalid Base64"),
- )
-
- for timestamp, msg in timestamps:
- with self.subTest(msg=msg):
- result = TokenRemover.is_valid_timestamp(timestamp)
- self.assertFalse(result)
-
- def test_is_valid_hmac_valid(self):
- """Should consider an HMAC valid if it has at least 3 unique characters."""
- valid_hmacs = (
- "VXmErH7j511turNpfURmb0rVNm8",
- "Ysnu2wacjaKs7qnoo46S8Dm2us8",
- "sJf6omBPORBPju3WJEIAcwW9Zds",
- "s45jqDV_Iisn-symw0yDRrk_jf4",
- )
-
- for hmac in valid_hmacs:
- with self.subTest(msg=hmac):
- result = TokenRemover.is_maybe_valid_hmac(hmac)
- self.assertTrue(result)
-
- def test_is_invalid_hmac_invalid(self):
- """Should consider an HMAC invalid if has fewer than 3 unique characters."""
- invalid_hmacs = (
- ("xxxxxxxxxxxxxxxxxx", "Single character"),
- ("XxXxXxXxXxXxXxXxXx", "Single character alternating case"),
- ("ASFasfASFasfASFASsf", "Three characters alternating-case"),
- ("asdasdasdasdasdasdasd", "Three characters one case"),
- )
-
- for hmac, msg in invalid_hmacs:
- with self.subTest(msg=msg):
- result = TokenRemover.is_maybe_valid_hmac(hmac)
- self.assertFalse(result)
-
- def test_mod_log_property(self):
- """The `mod_log` property should ask the bot to return the `ModLog` cog."""
- self.bot.get_cog.return_value = 'lemon'
- self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value)
- self.bot.get_cog.assert_called_once_with('ModLog')
-
- async def test_on_message_edit_uses_on_message(self):
- """The edit listener should delegate handling of the message to the normal listener."""
- self.cog.on_message = mock.create_autospec(self.cog.on_message, spec_set=True)
-
- await self.cog.on_message_edit(MockMessage(), self.msg)
- self.cog.on_message.assert_awaited_once_with(self.msg)
-
- @autospec(TokenRemover, "find_token_in_message", "take_action")
- async def test_on_message_takes_action(self, find_token_in_message, take_action):
- """Should take action if a valid token is found when a message is sent."""
- cog = TokenRemover(self.bot)
- found_token = "foobar"
- find_token_in_message.return_value = found_token
-
- await cog.on_message(self.msg)
-
- find_token_in_message.assert_called_once_with(self.msg)
- take_action.assert_awaited_once_with(cog, self.msg, found_token)
-
- @autospec(TokenRemover, "find_token_in_message", "take_action")
- async def test_on_message_skips_missing_token(self, find_token_in_message, take_action):
- """Shouldn't take action if a valid token isn't found when a message is sent."""
- cog = TokenRemover(self.bot)
- find_token_in_message.return_value = False
-
- await cog.on_message(self.msg)
-
- find_token_in_message.assert_called_once_with(self.msg)
- take_action.assert_not_awaited()
-
- @autospec(TokenRemover, "find_token_in_message")
- async def test_on_message_ignores_dms_bots(self, find_token_in_message):
- """Shouldn't parse a message if it is a DM or authored by a bot."""
- cog = TokenRemover(self.bot)
- dm_msg = MockMessage(guild=None)
- bot_msg = MockMessage(author=MagicMock(bot=True))
-
- for msg in (dm_msg, bot_msg):
- await cog.on_message(msg)
- find_token_in_message.assert_not_called()
-
- @autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_no_matches(self, token_re):
- """None should be returned if the regex matches no tokens in a message."""
- token_re.finditer.return_value = ()
-
- return_value = TokenRemover.find_token_in_message(self.msg)
-
- self.assertIsNone(return_value)
- token_re.finditer.assert_called_once_with(self.msg.content)
-
- @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
- @autospec("bot.exts.filters.token_remover", "Token")
- @autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_valid_match(
- self,
- token_re,
- token_cls,
- extract_user_id,
- is_valid_timestamp,
- is_maybe_valid_hmac,
- ):
- """The first match with a valid user ID, timestamp, and HMAC should be returned as a `Token`."""
- matches = [
- mock.create_autospec(Match, spec_set=True, instance=True),
- mock.create_autospec(Match, spec_set=True, instance=True),
- ]
- tokens = [
- mock.create_autospec(Token, spec_set=True, instance=True),
- mock.create_autospec(Token, spec_set=True, instance=True),
- ]
-
- token_re.finditer.return_value = matches
- token_cls.side_effect = tokens
- extract_user_id.side_effect = (None, True) # The 1st match will be invalid, 2nd one valid.
- is_valid_timestamp.return_value = True
- is_maybe_valid_hmac.return_value = True
-
- return_value = TokenRemover.find_token_in_message(self.msg)
-
- self.assertEqual(tokens[1], return_value)
- token_re.finditer.assert_called_once_with(self.msg.content)
-
- @autospec(TokenRemover, "extract_user_id", "is_valid_timestamp", "is_maybe_valid_hmac")
- @autospec("bot.exts.filters.token_remover", "Token")
- @autospec("bot.exts.filters.token_remover", "TOKEN_RE")
- def test_find_token_invalid_matches(
- self,
- token_re,
- token_cls,
- extract_user_id,
- is_valid_timestamp,
- is_maybe_valid_hmac,
- ):
- """None should be returned if no matches have valid user IDs, HMACs, and timestamps."""
- token_re.finditer.return_value = [mock.create_autospec(Match, spec_set=True, instance=True)]
- token_cls.return_value = mock.create_autospec(Token, spec_set=True, instance=True)
- extract_user_id.return_value = None
- is_valid_timestamp.return_value = False
- is_maybe_valid_hmac.return_value = False
-
- return_value = TokenRemover.find_token_in_message(self.msg)
-
- self.assertIsNone(return_value)
- token_re.finditer.assert_called_once_with(self.msg.content)
-
- def test_regex_invalid_tokens(self):
- """Messages without anything looking like a token are not matched."""
- tokens = (
- "",
- "lemon wins",
- "..",
- "x.y",
- "x.y.",
- ".y.z",
- ".y.",
- "..z",
- "x..z",
- " . . ",
- "\n.\n.\n",
- "hellö.world.bye",
- "base64.nötbåse64.morebase64",
- "19jd3J.dfkm3d.€víł§tüff",
- )
-
- for token in tokens:
- with self.subTest(token=token):
- results = token_remover.TOKEN_RE.findall(token)
- self.assertEqual(len(results), 0)
-
- def test_regex_valid_tokens(self):
- """Messages that look like tokens should be matched."""
- # Don't worry, these tokens have been invalidated.
- tokens = (
- "NDcyMjY1OTQzMDYy_DEzMz-y.XsyRkw.VXmErH7j511turNpfURmb0rVNm8",
- "NDcyMjY1OTQzMDYyNDEzMzMy.Xrim9Q.Ysnu2wacjaKs7qnoo46S8Dm2us8",
- "NDc1MDczNjI5Mzk5NTQ3OTA0.XsyR-w.sJf6omBPORBPju3WJEIAcwW9Zds",
- "NDY3MjIzMjMwNjUwNzc3NjQx.XsySD_.s45jqDV_Iisn-symw0yDRrk_jf4",
- )
-
- for token in tokens:
- with self.subTest(token=token):
- results = token_remover.TOKEN_RE.fullmatch(token)
- self.assertIsNotNone(results, f"{token} was not matched by the regex")
-
- def test_regex_matches_multiple_valid(self):
- """Should support multiple matches in the middle of a string."""
- token_1 = "NDY3MjIzMjMwNjUwNzc3NjQx.XsyWGg.uFNEQPCc4ePwGh7egG8UicQssz8"
- token_2 = "NDcyMjY1OTQzMDYyNDEzMzMy.XsyWMw.l8XPnDqb0lp-EiQ2g_0xVFT1pyc"
- message = f"garbage {token_1} hello {token_2} world"
-
- results = token_remover.TOKEN_RE.finditer(message)
- results = [match[0] for match in results]
- self.assertCountEqual((token_1, token_2), results)
-
- @autospec("bot.exts.filters.token_remover", "LOG_MESSAGE")
- def test_format_log_message(self, log_message):
- """Should correctly format the log message with info from the message and token."""
- token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- log_message.format.return_value = "Howdy"
-
- return_value = TokenRemover.format_log_message(self.msg, token)
-
- self.assertEqual(return_value, log_message.format.return_value)
- log_message.format.assert_called_once_with(
- author=format_user(self.msg.author),
- channel=self.msg.channel.mention,
- user_id=token.user_id,
- timestamp=token.timestamp,
- hmac="xxxxxxxxxxxxxxxxxxxxxxxxjf4",
- )
-
- @autospec("bot.exts.filters.token_remover", "UNKNOWN_USER_LOG_MESSAGE")
- async def test_format_userid_log_message_unknown(self, unknown_user_log_message,):
- """Should correctly format the user ID portion when the actual user it belongs to is unknown."""
- token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- unknown_user_log_message.format.return_value = " Partner"
- msg = MockMessage(id=555, content="hello world")
- msg.guild.get_member.return_value = None
- msg.guild.fetch_member.side_effect = NotFound(mock.Mock(status=404), "Not found")
-
- return_value = await TokenRemover.format_userid_log_message(msg, token)
-
- self.assertEqual(return_value, (unknown_user_log_message.format.return_value, False))
- unknown_user_log_message.format.assert_called_once_with(user_id=472265943062413332)
-
- @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
- async def test_format_userid_log_message_bot(self, known_user_log_message):
- """Should correctly format the user ID portion when the ID belongs to a known bot."""
- token = Token("NDcyMjY1OTQzMDYyNDEzMzMy", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- known_user_log_message.format.return_value = " Partner"
- msg = MockMessage(id=555, content="hello world")
- msg.guild.get_member.return_value.__str__.return_value = "Sam"
- msg.guild.get_member.return_value.bot = True
-
- return_value = await TokenRemover.format_userid_log_message(msg, token)
-
- self.assertEqual(return_value, (known_user_log_message.format.return_value, True))
-
- known_user_log_message.format.assert_called_once_with(
- user_id=472265943062413332,
- user_name="Sam",
- kind="BOT",
- )
-
- @autospec("bot.exts.filters.token_remover", "KNOWN_USER_LOG_MESSAGE")
- async def test_format_log_message_user_token_user(self, user_token_message):
- """Should correctly format the user ID portion when the ID belongs to a known user."""
- token = Token("NDY3MjIzMjMwNjUwNzc3NjQx", "XsySD_", "s45jqDV_Iisn-symw0yDRrk_jf4")
- user_token_message.format.return_value = "Partner"
-
- return_value = await TokenRemover.format_userid_log_message(self.msg, token)
-
- self.assertEqual(return_value, (user_token_message.format.return_value, True))
- user_token_message.format.assert_called_once_with(
- user_id=467223230650777641,
- user_name="Woody",
- kind="USER",
- )
-
- @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
- @autospec("bot.exts.filters.token_remover", "log")
- @autospec(TokenRemover, "format_log_message", "format_userid_log_message")
- async def test_take_action(self, format_log_message, format_userid_log_message, logger, mod_log_property):
- """Should delete the message and send a mod log."""
- cog = TokenRemover(self.bot)
- mod_log = mock.create_autospec(ModLog, spec_set=True, instance=True)
- token = mock.create_autospec(Token, spec_set=True, instance=True)
- token.user_id = "no-id"
- log_msg = "testing123"
- userid_log_message = "userid-log-message"
-
- mod_log_property.return_value = mod_log
- format_log_message.return_value = log_msg
- format_userid_log_message.return_value = (userid_log_message, True)
-
- await cog.take_action(self.msg, token)
-
- self.msg.delete.assert_called_once_with()
- self.msg.channel.send.assert_called_once_with(
- token_remover.DELETION_MESSAGE_TEMPLATE.format(mention=self.msg.author.mention)
- )
-
- format_log_message.assert_called_once_with(self.msg, token)
- format_userid_log_message.assert_called_once_with(self.msg, token)
- logger.debug.assert_called_with(log_msg)
- self.bot.stats.incr.assert_called_once_with("tokens.removed_tokens")
-
- mod_log.ignore.assert_called_once_with(constants.Event.message_delete, self.msg.id)
- mod_log.send_log_message.assert_called_once_with(
- icon_url=constants.Icons.token_removed,
- colour=Colour(constants.Colours.soft_red),
- title="Token removed!",
- text=log_msg + "\n" + userid_log_message,
- thumbnail=self.msg.author.display_avatar.url,
- channel_id=constants.Channels.mod_alerts,
- ping_everyone=True,
- )
-
- @mock.patch.object(TokenRemover, "mod_log", new_callable=mock.PropertyMock)
- async def test_take_action_delete_failure(self, mod_log_property):
- """Shouldn't send any messages if the token message can't be deleted."""
- cog = TokenRemover(self.bot)
- mod_log_property.return_value = mock.create_autospec(ModLog, spec_set=True, instance=True)
- self.msg.delete.side_effect = NotFound(MagicMock(), MagicMock())
-
- token = mock.create_autospec(Token, spec_set=True, instance=True)
- await cog.take_action(self.msg, token)
-
- self.msg.delete.assert_called_once_with()
- self.msg.channel.send.assert_not_awaited()
-
-
-class TokenRemoverExtensionTests(unittest.IsolatedAsyncioTestCase):
- """Tests for the token_remover extension."""
-
- @autospec("bot.exts.filters.token_remover", "TokenRemover")
- async def test_extension_setup(self, cog):
- """The TokenRemover cog should be added."""
- bot = MockBot()
- await token_remover.setup(bot)
-
- cog.assert_called_once_with(bot)
- bot.add_cog.assert_awaited_once()
- self.assertTrue(isinstance(bot.add_cog.call_args.args[0], TokenRemover))
diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py
index 9dcf7fd8c..79ac8ea2c 100644
--- a/tests/bot/exts/utils/snekbox/test_snekbox.py
+++ b/tests/bot/exts/utils/snekbox/test_snekbox.py
@@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.upload_output = AsyncMock() # Should not be called
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code('MyAwesomeCode')
@@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
@@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.upload_output = AsyncMock() # This function isn't called
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, []))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
@@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.cog.upload_output = AsyncMock() # This function isn't called
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False)
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"]))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code("MyAwesomeCode").as_version("3.11")
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
deleted file mode 100644
index 0d570f5a3..000000000
--- a/tests/bot/rules/__init__.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import unittest
-from abc import ABCMeta, abstractmethod
-from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
-
-from tests.helpers import MockMessage
-
-
-class DisallowedCase(NamedTuple):
- """Encapsulation for test cases expected to fail."""
- recent_messages: List[MockMessage]
- culprits: Iterable[str]
- n_violations: int
-
-
-class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):
- """
- Abstract class for antispam rule test cases.
-
- Tests for specific rules should inherit from `RuleTest` and implement
- `relevant_messages` and `get_report`. Each instance should also set the
- `apply` and `config` attributes as necessary.
-
- The execution of test cases can then be delegated to the `run_allowed`
- and `run_disallowed` methods.
- """
-
- apply: Callable # The tested rule's apply function
- config: Dict[str, int]
-
- async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
- """Run all `cases` against `self.apply` expecting them to pass."""
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config,
- ):
- self.assertIsNone(
- await self.apply(last_message, recent_messages, self.config)
- )
-
- async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
- """Run all `cases` against `self.apply` expecting them to fail."""
- for case in cases:
- recent_messages, culprits, n_violations = case
- last_message = recent_messages[0]
- relevant_messages = self.relevant_messages(case)
- desired_output = (
- self.get_report(case),
- culprits,
- relevant_messages,
- )
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- n_violations=n_violations,
- config=self.config,
- ):
- self.assertTupleEqual(
- await self.apply(last_message, recent_messages, self.config),
- desired_output,
- )
-
- @abstractmethod
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- """Give expected relevant messages for `case`."""
- raise NotImplementedError # pragma: no cover
-
- @abstractmethod
- def get_report(self, case: DisallowedCase) -> str:
- """Give expected error report for `case`."""
- raise NotImplementedError # pragma: no cover
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
deleted file mode 100644
index d7e779221..000000000
--- a/tests/bot/rules/test_attachments.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from typing import Iterable
-
-from bot.rules import attachments
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, total_attachments: int) -> MockMessage:
- """Builds a message with `total_attachments` attachments."""
- return MockMessage(author=author, attachments=list(range(total_attachments)))
-
-
-class AttachmentRuleTests(RuleTest):
- """Tests applying the `attachments` antispam rule."""
-
- def setUp(self):
- self.apply = attachments.apply
- self.config = {"max": 5, "interval": 10}
-
- async def test_allows_messages_without_too_many_attachments(self):
- """Messages without too many attachments are allowed as-is."""
- cases = (
- [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)],
- [make_msg("bob", 2), make_msg("bob", 2)],
- [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_with_too_many_attachments(self):
- """Messages with too many attachments trigger the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
- ("bob",),
- 10,
- ),
- DisallowedCase(
- [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
- ("bob",),
- 6,
- ),
- DisallowedCase(
- [make_msg("alice", 6)],
- ("alice",),
- 6,
- ),
- DisallowedCase(
- [make_msg("alice", 1) for _ in range(6)],
- ("alice",),
- 6,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
deleted file mode 100644
index 03682966b..000000000
--- a/tests/bot/rules/test_burst.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from typing import Iterable
-
-from bot.rules import burst
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str) -> MockMessage:
- """
- Init a MockMessage instance with author set to `author`.
-
- This serves as a shorthand / alias to keep the test cases visually clean.
- """
- return MockMessage(author=author)
-
-
-class BurstRuleTests(RuleTest):
- """Tests the `burst` antispam rule."""
-
- def setUp(self):
- self.apply = burst.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases which do not violate the rule."""
- cases = (
- [make_msg("bob"), make_msg("bob")],
- [make_msg("bob"), make_msg("alice"), make_msg("bob")],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases where the amount of messages exceeds the limit, triggering the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("bob")],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
- ("bob",),
- 3,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
deleted file mode 100644
index 3275143d5..000000000
--- a/tests/bot/rules/test_burst_shared.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Iterable
-
-from bot.rules import burst_shared
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str) -> MockMessage:
- """
- Init a MockMessage instance with the passed arg.
-
- This serves as a shorthand / alias to keep the test cases visually clean.
- """
- return MockMessage(author=author)
-
-
-class BurstSharedRuleTests(RuleTest):
- """Tests the `burst_shared` antispam rule."""
-
- def setUp(self):
- self.apply = burst_shared.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """
- Cases that do not violate the rule.
-
- There really isn't more to test here than a single case.
- """
- cases = (
- [make_msg("spongebob"), make_msg("patrick")],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases where the amount of messages exceeds the limit, triggering the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("bob")],
- {"bob"},
- 3,
- ),
- DisallowedCase(
- [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
- {"bob", "alice"},
- 4,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- return case.recent_messages
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
deleted file mode 100644
index f1e3c76a7..000000000
--- a/tests/bot/rules/test_chars.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Iterable
-
-from bot.rules import chars
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, n_chars: int) -> MockMessage:
- """Build a message with arbitrary content of `n_chars` length."""
- return MockMessage(author=author, content="A" * n_chars)
-
-
-class CharsRuleTests(RuleTest):
- """Tests the `chars` antispam rule."""
-
- def setUp(self):
- self.apply = chars.apply
- self.config = {
- "max": 20, # Max allowed sum of chars per user
- "interval": 10,
- }
-
- async def test_allows_messages_within_limit(self):
- """Cases with a total amount of chars within limit."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 20)],
- [make_msg("bob", 15), make_msg("alice", 15)],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases where the total amount of chars exceeds the limit, triggering the rule."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 21)],
- ("bob",),
- 21,
- ),
- DisallowedCase(
- [make_msg("bob", 15), make_msg("bob", 15)],
- ("bob",),
- 30,
- ),
- DisallowedCase(
- [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
- ("alice",),
- 30,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
deleted file mode 100644
index 66c2d9f92..000000000
--- a/tests/bot/rules/test_discord_emojis.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from typing import Iterable
-
-from bot.rules import discord_emojis
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
-unicode_emoji = "🧪"
-
-
-def make_msg(author: str, n_emojis: int, emoji: str = discord_emoji) -> MockMessage:
- """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis."""
- return MockMessage(author=author, content=emoji * n_emojis)
-
-
-class DiscordEmojisRuleTests(RuleTest):
- """Tests for the `discord_emojis` antispam rule."""
-
- def setUp(self):
- self.apply = discord_emojis.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases with a total amount of discord and unicode emojis within limit."""
- cases = (
- [make_msg("bob", 2)],
- [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
- [make_msg("bob", 2, unicode_emoji)],
- [
- make_msg("alice", 1, unicode_emoji),
- make_msg("bob", 2, unicode_emoji),
- make_msg("alice", 1, unicode_emoji)
- ],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases with more than the allowed amount of discord and unicode emojis."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
- ("alice",),
- 4,
- ),
- DisallowedCase(
- [make_msg("bob", 3, unicode_emoji)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [
- make_msg("alice", 2, unicode_emoji),
- make_msg("bob", 2, unicode_emoji),
- make_msg("alice", 2, unicode_emoji)
- ],
- ("alice",),
- 4
- )
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
deleted file mode 100644
index 9bd886a77..000000000
--- a/tests/bot/rules/test_duplicates.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Iterable
-
-from bot.rules import duplicates
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, content: str) -> MockMessage:
- """Give a MockMessage instance with `author` and `content` attrs."""
- return MockMessage(author=author, content=content)
-
-
-class DuplicatesRuleTests(RuleTest):
- """Tests the `duplicates` antispam rule."""
-
- def setUp(self):
- self.apply = duplicates.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases which do not violate the rule."""
- cases = (
- [make_msg("alice", "A"), make_msg("alice", "A")],
- [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate
- [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases with too many duplicate messages from the same author."""
- cases = (
- DisallowedCase(
- [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")],
- ("alice",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")],
- ("bob",),
- 3, # 4 duplicate messages, but only 3 from bob
- ),
- DisallowedCase(
- [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")],
- ("bob",),
- 3, # 4 message from bob, but only 3 duplicates
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if (
- msg.author == last_message.author
- and msg.content == last_message.content
- )
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
deleted file mode 100644
index b091bd9d7..000000000
--- a/tests/bot/rules/test_links.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from typing import Iterable
-
-from bot.rules import links
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, total_links: int) -> MockMessage:
- """Makes a message with `total_links` links."""
- content = " ".join(["https://pydis.com"] * total_links)
- return MockMessage(author=author, content=content)
-
-
-class LinksTests(RuleTest):
- """Tests applying the `links` rule."""
-
- def setUp(self):
- self.apply = links.apply
- self.config = {
- "max": 2,
- "interval": 10
- }
-
- async def test_links_within_limit(self):
- """Messages with an allowed amount of links."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 2)],
- [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
- [make_msg("bob", 1), make_msg("bob", 1)],
- [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
- )
-
- await self.run_allowed(cases)
-
- async def test_links_exceeding_limit(self):
- """Messages with a a higher than allowed amount of links."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 1), make_msg("bob", 2)],
- ("bob",),
- 3
- ),
- DisallowedCase(
- [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
- ("alice",),
- 3
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
- ("alice",),
- 3
- )
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
deleted file mode 100644
index e1f904917..000000000
--- a/tests/bot/rules/test_mentions.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from typing import Iterable, Optional
-
-import discord
-
-from bot.rules import mentions
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMember, MockMessage, MockMessageReference
-
-
-def make_msg(
- author: str,
- total_user_mentions: int,
- total_bot_mentions: int = 0,
- *,
- reference: Optional[MockMessageReference] = None
-) -> MockMessage:
- """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions."""
- user_mentions = [MockMember() for _ in range(total_user_mentions)]
- bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)]
-
- mentions = user_mentions + bot_mentions
- if reference is not None:
- # For the sake of these tests we assume that all references are mentions.
- mentions.append(reference.resolved.author)
- msg_type = discord.MessageType.reply
- else:
- msg_type = discord.MessageType.default
-
- return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type)
-
-
-class TestMentions(RuleTest):
- """Tests applying the `mentions` antispam rule."""
-
- def setUp(self):
- self.apply = mentions.apply
- self.config = {
- "max": 2,
- "interval": 10,
- }
-
- async def test_mentions_within_limit(self):
- """Messages with an allowed amount of mentions."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 2)],
- [make_msg("bob", 1), make_msg("bob", 1)],
- [make_msg("bob", 1), make_msg("alice", 2)],
- )
-
- await self.run_allowed(cases)
-
- async def test_mentions_exceeding_limit(self):
- """Messages with a higher than allowed amount of mentions."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
- ("alice",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
- ("bob",),
- 4,
- ),
- DisallowedCase(
- [make_msg("bob", 3, 1)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 3, reference=MockMessageReference())],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))],
- ("bob",),
- 3
- )
- )
-
- await self.run_disallowed(cases)
-
- async def test_ignore_bot_mentions(self):
- """Messages with an allowed amount of mentions, also containing bot mentions."""
- cases = (
- [make_msg("bob", 0, 3)],
- [make_msg("bob", 2, 1)],
- [make_msg("bob", 1, 2), make_msg("bob", 1, 2)],
- [make_msg("bob", 1, 5), make_msg("alice", 2, 5)]
- )
-
- await self.run_allowed(cases)
-
- async def test_ignore_reply_mentions(self):
- """Messages with an allowed amount of mentions in the content, also containing reply mentions."""
- cases = (
- [
- make_msg("bob", 2, reference=MockMessageReference())
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True))
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference()),
- make_msg("bob", 0, reference=MockMessageReference())
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)),
- make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True))
- ]
- )
-
- await self.run_allowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
deleted file mode 100644
index e35377773..000000000
--- a/tests/bot/rules/test_newlines.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from typing import Iterable, List
-
-from bot.rules import newlines
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
- """Init a MockMessage instance with `author` and content configured by `newline_groups".
-
- Configure content by passing a list of ints, where each int `n` will generate
- a separate group of `n` newlines.
-
- Example:
- newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n"
- """
- content = " ".join("\n" * n for n in newline_groups)
- return MockMessage(author=author, content=content)
-
-
-class TotalNewlinesRuleTests(RuleTest):
- """Tests the `newlines` antispam rule against allowed cases and total newline count violations."""
-
- def setUp(self):
- self.apply = newlines.apply
- self.config = {
- "max": 5, # Max sum of newlines in relevant messages
- "max_consecutive": 3, # Max newlines in one group, in one message
- "interval": 10,
- }
-
- async def test_allows_messages_within_limit(self):
- """Cases which do not violate the rule."""
- cases = (
- [make_msg("alice", [])], # Single message with no newlines
- [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages
- [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author
- [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_total(self):
- """Cases which violate the rule by having too many newlines in total."""
- cases = (
- DisallowedCase( # Alice sends a total of 6 newlines (disallowed)
- [make_msg("alice", [2, 2]), make_msg("alice", [2])],
- ("alice",),
- 6,
- ),
- DisallowedCase( # Here we test that only alice's newlines count in the sum
- [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])],
- ("alice",),
- 7,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_author = case.recent_messages[0].author
- return tuple(msg for msg in case.recent_messages if msg.author == last_author)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} newlines in {self.config['interval']}s"
-
-
-class GroupNewlinesRuleTests(RuleTest):
- """
- Tests the `newlines` antispam rule against max consecutive newline violations.
-
- As these violations yield a different error report, they require a different
- `get_report` implementation.
- """
-
- def setUp(self):
- self.apply = newlines.apply
- self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
-
- async def test_disallows_messages_consecutive(self):
- """Cases which violate the rule due to having too many consecutive newlines."""
- cases = (
- DisallowedCase( # Bob sends a group of newlines too large
- [make_msg("bob", [4])],
- ("bob",),
- 4,
- ),
- DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed)
- [make_msg("alice", [1]), make_msg("alice", [4])],
- ("alice",),
- 4,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_author = case.recent_messages[0].author
- return tuple(msg for msg in case.recent_messages if msg.author == last_author)
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
deleted file mode 100644
index 26c05d527..000000000
--- a/tests/bot/rules/test_role_mentions.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from typing import Iterable
-
-from bot.rules import role_mentions
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage
-
-
-def make_msg(author: str, n_mentions: int) -> MockMessage:
- """Build a MockMessage instance with `n_mentions` role mentions."""
- return MockMessage(author=author, role_mentions=[None] * n_mentions)
-
-
-class RoleMentionsRuleTests(RuleTest):
- """Tests for the `role_mentions` antispam rule."""
-
- def setUp(self):
- self.apply = role_mentions.apply
- self.config = {"max": 2, "interval": 10}
-
- async def test_allows_messages_within_limit(self):
- """Cases with a total amount of role mentions within limit."""
- cases = (
- [make_msg("bob", 2)],
- [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
- )
-
- await self.run_allowed(cases)
-
- async def test_disallows_messages_beyond_limit(self):
- """Cases with more than the allowed amount of role mentions."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
- ("alice",),
- 4,
- ),
- )
-
- await self.run_disallowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} role mentions in {self.config['interval']}s"
diff --git a/tests/helpers.py b/tests/helpers.py
index 1a71f210a..020f1aee5 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data
class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
- A MagicMock subclass to mock TextChannel objects.
+ A MagicMock subclass to mock DMChannel objects.
- Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ Instances of this class will follow the specifications of `discord.DMChannel` instances. For
more information, see the `MockGuild` docstring.
"""
spec_set = dm_channel_instance
def __init__(self, **kwargs) -> None:
- default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()}
+ default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None}
super().__init__(**collections.ChainMap(kwargs, default_kwargs))
@@ -423,7 +423,7 @@ category_channel_instance = discord.CategoryChannel(
class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
def __init__(self, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id)}
- super().__init__(**collections.ChainMap(default_kwargs, kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
# Create a Message instance to get a realistic MagicMock of `discord.Message`