aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/exts/filtering/_settings_types/validations/bypass_roles.py2
-rw-r--r--bot/exts/filtering/_settings_types/validations/filter_dm.py2
-rw-r--r--tests/bot/exts/filtering/test_settings.py2
-rw-r--r--tests/bot/exts/filtering/test_settings_entries.py182
-rw-r--r--tests/bot/rules/test_mentions.py131
-rw-r--r--tests/helpers.py6
6 files changed, 44 insertions, 281 deletions
diff --git a/bot/exts/filtering/_settings_types/validations/bypass_roles.py b/bot/exts/filtering/_settings_types/validations/bypass_roles.py
index a5c18cffc..c1e6f885d 100644
--- a/bot/exts/filtering/_settings_types/validations/bypass_roles.py
+++ b/bot/exts/filtering/_settings_types/validations/bypass_roles.py
@@ -15,7 +15,7 @@ class RoleBypass(ValidationEntry):
bypass_roles: set[Union[int, str]]
- @validator("bypass_roles", each_item=True)
+ @validator("bypass_roles", pre=True, each_item=True)
@classmethod
def maybe_cast_to_int(cls, role: str) -> Union[int, str]:
"""If the string is alphanumeric, cast it to int."""
diff --git a/bot/exts/filtering/_settings_types/validations/filter_dm.py b/bot/exts/filtering/_settings_types/validations/filter_dm.py
index 93022320f..b9e566253 100644
--- a/bot/exts/filtering/_settings_types/validations/filter_dm.py
+++ b/bot/exts/filtering/_settings_types/validations/filter_dm.py
@@ -14,4 +14,4 @@ class FilterDM(ValidationEntry):
def triggers_on(self, ctx: FilterContext) -> bool:
"""Return whether the filter should be triggered even if it was triggered in DMs."""
- return hasattr(ctx.channel, "guild") or self.filter_dm
+ return ctx.channel.guild is not None or self.filter_dm
diff --git a/tests/bot/exts/filtering/test_settings.py b/tests/bot/exts/filtering/test_settings.py
index ac21a5d47..5a289c1cf 100644
--- a/tests/bot/exts/filtering/test_settings.py
+++ b/tests/bot/exts/filtering/test_settings.py
@@ -11,7 +11,7 @@ class FilterTests(unittest.TestCase):
"""`create_settings` should return a tuple of two Nones when passed an empty dict."""
result = create_settings({})
- self.assertEquals(result, (None, None))
+ self.assertEqual(result, (None, None))
def test_unrecognized_entry_makes_a_warning(self):
"""When an unrecognized entry name is passed to `create_settings`, it should be added to `_already_warned`."""
diff --git a/tests/bot/exts/filtering/test_settings_entries.py b/tests/bot/exts/filtering/test_settings_entries.py
index 8dba5cb26..34b155d6b 100644
--- a/tests/bot/exts/filtering/test_settings_entries.py
+++ b/tests/bot/exts/filtering/test_settings_entries.py
@@ -1,9 +1,7 @@
import unittest
from bot.exts.filtering._filter_context import Event, FilterContext
-from bot.exts.filtering._settings_types.actions.infraction_and_notification import (
- Infraction, InfractionAndNotification, superstar
-)
+from bot.exts.filtering._settings_types.actions.infraction_and_notification import Infraction, InfractionAndNotification
from bot.exts.filtering._settings_types.validations.bypass_roles import RoleBypass
from bot.exts.filtering._settings_types.validations.channel_scope import ChannelScope
from bot.exts.filtering._settings_types.validations.filter_dm import FilterDM
@@ -23,7 +21,7 @@ class FilterTests(unittest.TestCase):
"""The role bypass should trigger when a user has no roles."""
member = MockMember()
self.ctx.author = member
- bypass_entry = RoleBypass(["123"])
+ bypass_entry = RoleBypass(bypass_roles=["123"])
result = bypass_entry.triggers_on(self.ctx)
@@ -43,7 +41,7 @@ class FilterTests(unittest.TestCase):
user_roles = [MockRole(id=role_id) for role_id in user_role_ids]
member = MockMember(roles=user_roles)
self.ctx.author = member
- bypass_entry = RoleBypass(bypasses)
+ bypass_entry = RoleBypass(bypass_roles=bypasses)
result = bypass_entry.triggers_on(self.ctx)
@@ -52,7 +50,7 @@ class FilterTests(unittest.TestCase):
def test_context_doesnt_trigger_for_empty_channel_scope(self):
"""A filter is enabled for all channels by default."""
channel = MockTextChannel()
- scope = ChannelScope({"disabled_channels": None, "disabled_categories": None, "enabled_channels": None})
+ scope = ChannelScope(disabled_channels=None, disabled_categories=None, enabled_channels=None)
self.ctx.channel = channel
result = scope.triggers_on(self.ctx)
@@ -62,7 +60,7 @@ class FilterTests(unittest.TestCase):
def test_context_doesnt_trigger_for_disabled_channel(self):
"""A filter shouldn't trigger if it's been disabled in the channel."""
channel = MockTextChannel(id=123)
- scope = ChannelScope({"disabled_channels": ["123"], "disabled_categories": None, "enabled_channels": None})
+ scope = ChannelScope(disabled_channels=["123"], disabled_categories=None, enabled_channels=None)
self.ctx.channel = channel
result = scope.triggers_on(self.ctx)
@@ -72,9 +70,7 @@ class FilterTests(unittest.TestCase):
def test_context_doesnt_trigger_in_disabled_category(self):
"""A filter shouldn't trigger if it's been disabled in the category."""
channel = MockTextChannel(category=MockCategoryChannel(id=456))
- scope = ChannelScope({
- "disabled_channels": None, "disabled_categories": ["456"], "enabled_channels": None
- })
+ scope = ChannelScope(disabled_channels=None, disabled_categories=["456"], enabled_channels=None)
self.ctx.channel = channel
result = scope.triggers_on(self.ctx)
@@ -84,7 +80,7 @@ class FilterTests(unittest.TestCase):
def test_context_triggers_in_enabled_channel_in_disabled_category(self):
"""A filter should trigger in an enabled channel even if it's been disabled in the category."""
channel = MockTextChannel(id=123, category=MockCategoryChannel(id=234))
- scope = ChannelScope({"disabled_channels": None, "disabled_categories": ["234"], "enabled_channels": ["123"]})
+ scope = ChannelScope(disabled_channels=None, disabled_categories=["234"], enabled_channels=["123"])
self.ctx.channel = channel
result = scope.triggers_on(self.ctx)
@@ -102,7 +98,7 @@ class FilterTests(unittest.TestCase):
for apply_in_dms, channel, expected in cases:
with self.subTest(apply_in_dms=apply_in_dms, channel=channel):
- filter_dms = FilterDM(apply_in_dms)
+ filter_dms = FilterDM(filter_dm=apply_in_dms)
self.ctx.channel = channel
result = filter_dms.triggers_on(self.ctx)
@@ -111,162 +107,60 @@ class FilterTests(unittest.TestCase):
def test_infraction_merge_of_same_infraction_type(self):
"""When both infractions are of the same type, the one with the longer duration wins."""
- infraction1 = InfractionAndNotification({
- "infraction_type": "mute",
- "infraction_reason": "hi",
- "infraction_duration": 10,
- "dm_content": "how",
- "dm_embed": "what is"
- })
- infraction2 = InfractionAndNotification({
- "infraction_type": "mute",
- "infraction_reason": "there",
- "infraction_duration": 20,
- "dm_content": "are you",
- "dm_embed": "your name"
- })
+ infraction1 = InfractionAndNotification(
+ infraction_type="MUTE",
+ infraction_reason="hi",
+ infraction_duration=10,
+ dm_content="how",
+ dm_embed="what is"
+ )
+ infraction2 = InfractionAndNotification(
+ infraction_type="MUTE",
+ infraction_reason="there",
+ infraction_duration=20,
+ dm_content="are you",
+ dm_embed="your name"
+ )
result = infraction1 | infraction2
self.assertDictEqual(
- result.to_dict(),
+ result.dict(),
{
"infraction_type": Infraction.MUTE,
"infraction_reason": "there",
"infraction_duration": 20.0,
"dm_content": "are you",
"dm_embed": "your name",
- "_superstar": None
}
)
def test_infraction_merge_of_different_infraction_types(self):
"""If there are two different infraction types, the one higher up the hierarchy should be picked."""
- infraction1 = InfractionAndNotification({
- "infraction_type": "mute",
- "infraction_reason": "hi",
- "infraction_duration": 20,
- "dm_content": "",
- "dm_embed": ""
- })
- infraction2 = InfractionAndNotification({
- "infraction_type": "ban",
- "infraction_reason": "",
- "infraction_duration": 10,
- "dm_content": "there",
- "dm_embed": ""
- })
+ infraction1 = InfractionAndNotification(
+ infraction_type="MUTE",
+ infraction_reason="hi",
+ infraction_duration=20,
+ dm_content="",
+ dm_embed=""
+ )
+ infraction2 = InfractionAndNotification(
+ infraction_type="BAN",
+ infraction_reason="",
+ infraction_duration=10,
+ dm_content="there",
+ dm_embed=""
+ )
result = infraction1 | infraction2
self.assertDictEqual(
- result.to_dict(),
+ result.dict(),
{
"infraction_type": Infraction.BAN,
"infraction_reason": "",
"infraction_duration": 10.0,
"dm_content": "there",
"dm_embed": "",
- "_superstar": None
- }
- )
-
- def test_infraction_merge_with_a_superstar(self):
- """If there is a superstar infraction, it should be added to a separate field."""
- infraction1 = InfractionAndNotification({
- "infraction_type": "mute",
- "infraction_reason": "hi",
- "infraction_duration": 20,
- "dm_content": "there",
- "dm_embed": ""
- })
- infraction2 = InfractionAndNotification({
- "infraction_type": "superstar",
- "infraction_reason": "hello",
- "infraction_duration": 10,
- "dm_content": "you",
- "dm_embed": ""
- })
-
- result = infraction1 | infraction2
-
- self.assertDictEqual(
- result.to_dict(),
- {
- "infraction_type": Infraction.MUTE,
- "infraction_reason": "hi",
- "infraction_duration": 20.0,
- "dm_content": "there",
- "dm_embed": "",
- "_superstar": superstar("hello", 10.0)
- }
- )
-
- def test_merge_two_superstar_infractions(self):
- """When two superstar infractions are merged, the infraction type remains a superstar."""
- infraction1 = InfractionAndNotification({
- "infraction_type": "superstar",
- "infraction_reason": "hi",
- "infraction_duration": 20,
- "dm_content": "",
- "dm_embed": ""
- })
- infraction2 = InfractionAndNotification({
- "infraction_type": "superstar",
- "infraction_reason": "",
- "infraction_duration": 10,
- "dm_content": "there",
- "dm_embed": ""
- })
-
- result = infraction1 | infraction2
-
- self.assertDictEqual(
- result.to_dict(),
- {
- "infraction_type": Infraction.SUPERSTAR,
- "infraction_reason": "hi",
- "infraction_duration": 20.0,
- "dm_content": "",
- "dm_embed": "",
- "_superstar": None
- }
- )
-
- def test_merge_a_voiceban_and_a_superstar_with_another_superstar(self):
- """An infraction with a superstar merged with a superstar should combine under `_superstar`."""
- infraction1 = InfractionAndNotification({
- "infraction_type": "voice ban",
- "infraction_reason": "hi",
- "infraction_duration": 20,
- "dm_content": "hello",
- "dm_embed": ""
- })
- infraction2 = InfractionAndNotification({
- "infraction_type": "superstar",
- "infraction_reason": "bla",
- "infraction_duration": 10,
- "dm_content": "there",
- "dm_embed": ""
- })
- infraction3 = InfractionAndNotification({
- "infraction_type": "superstar",
- "infraction_reason": "blabla",
- "infraction_duration": 20,
- "dm_content": "there",
- "dm_embed": ""
- })
-
- result = infraction1 | infraction2 | infraction3
-
- self.assertDictEqual(
- result.to_dict(),
- {
- "infraction_type": Infraction.VOICE_BAN,
- "infraction_reason": "hi",
- "infraction_duration": 20,
- "dm_content": "hello",
- "dm_embed": "",
- "_superstar": superstar("blabla", 20)
}
)
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
deleted file mode 100644
index e1f904917..000000000
--- a/tests/bot/rules/test_mentions.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from typing import Iterable, Optional
-
-import discord
-
-from bot.rules import mentions
-from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMember, MockMessage, MockMessageReference
-
-
-def make_msg(
- author: str,
- total_user_mentions: int,
- total_bot_mentions: int = 0,
- *,
- reference: Optional[MockMessageReference] = None
-) -> MockMessage:
- """Makes a message from `author` with `total_user_mentions` user mentions and `total_bot_mentions` bot mentions."""
- user_mentions = [MockMember() for _ in range(total_user_mentions)]
- bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)]
-
- mentions = user_mentions + bot_mentions
- if reference is not None:
- # For the sake of these tests we assume that all references are mentions.
- mentions.append(reference.resolved.author)
- msg_type = discord.MessageType.reply
- else:
- msg_type = discord.MessageType.default
-
- return MockMessage(author=author, mentions=mentions, reference=reference, type=msg_type)
-
-
-class TestMentions(RuleTest):
- """Tests applying the `mentions` antispam rule."""
-
- def setUp(self):
- self.apply = mentions.apply
- self.config = {
- "max": 2,
- "interval": 10,
- }
-
- async def test_mentions_within_limit(self):
- """Messages with an allowed amount of mentions."""
- cases = (
- [make_msg("bob", 0)],
- [make_msg("bob", 2)],
- [make_msg("bob", 1), make_msg("bob", 1)],
- [make_msg("bob", 1), make_msg("alice", 2)],
- )
-
- await self.run_allowed(cases)
-
- async def test_mentions_exceeding_limit(self):
- """Messages with a higher than allowed amount of mentions."""
- cases = (
- DisallowedCase(
- [make_msg("bob", 3)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
- ("alice",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
- ("bob",),
- 4,
- ),
- DisallowedCase(
- [make_msg("bob", 3, 1)],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 3, reference=MockMessageReference())],
- ("bob",),
- 3,
- ),
- DisallowedCase(
- [make_msg("bob", 3, reference=MockMessageReference(reference_author_is_bot=True))],
- ("bob",),
- 3
- )
- )
-
- await self.run_disallowed(cases)
-
- async def test_ignore_bot_mentions(self):
- """Messages with an allowed amount of mentions, also containing bot mentions."""
- cases = (
- [make_msg("bob", 0, 3)],
- [make_msg("bob", 2, 1)],
- [make_msg("bob", 1, 2), make_msg("bob", 1, 2)],
- [make_msg("bob", 1, 5), make_msg("alice", 2, 5)]
- )
-
- await self.run_allowed(cases)
-
- async def test_ignore_reply_mentions(self):
- """Messages with an allowed amount of mentions in the content, also containing reply mentions."""
- cases = (
- [
- make_msg("bob", 2, reference=MockMessageReference())
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True))
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference()),
- make_msg("bob", 0, reference=MockMessageReference())
- ],
- [
- make_msg("bob", 2, reference=MockMessageReference(reference_author_is_bot=True)),
- make_msg("bob", 0, reference=MockMessageReference(reference_author_is_bot=True))
- ]
- )
-
- await self.run_allowed(cases)
-
- def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
- last_message = case.recent_messages[0]
- return tuple(
- msg
- for msg in case.recent_messages
- if msg.author == last_message.author
- )
-
- def get_report(self, case: DisallowedCase) -> str:
- return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/helpers.py b/tests/helpers.py
index 28a8e40a7..35a8a71f7 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -393,15 +393,15 @@ dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data
class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
"""
- A MagicMock subclass to mock TextChannel objects.
+ A MagicMock subclass to mock DMChannel objects.
- Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ Instances of this class will follow the specifications of `discord.DMChannel` instances. For
more information, see the `MockGuild` docstring.
"""
spec_set = dm_channel_instance
def __init__(self, **kwargs) -> None:
- default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()}
+ default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser(), 'guild': None}
super().__init__(**collections.ChainMap(kwargs, default_kwargs))