aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/moderation/__init__.py4
-rw-r--r--bot/cogs/moderation/incidents.py334
-rw-r--r--bot/constants.py6
-rw-r--r--config-default.yml7
-rw-r--r--tests/bot/cogs/moderation/test_incidents.py452
5 files changed, 801 insertions, 2 deletions
diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py
index 6880ca1bd..4455705f7 100644
--- a/bot/cogs/moderation/__init__.py
+++ b/bot/cogs/moderation/__init__.py
@@ -1,4 +1,5 @@
from bot.bot import Bot
+from .incidents import Incidents
from .infractions import Infractions
from .management import ModManagement
from .modlog import ModLog
@@ -7,7 +8,8 @@ from .superstarify import Superstarify
def setup(bot: Bot) -> None:
- """Load the Infractions, ModManagement, ModLog, Silence, and Superstarify cogs."""
+ """Load the Incidents, Infractions, ModManagement, ModLog, Silence, and Superstarify cogs."""
+ bot.add_cog(Incidents(bot))
bot.add_cog(Infractions(bot))
bot.add_cog(ModLog(bot))
bot.add_cog(ModManagement(bot))
diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py
new file mode 100644
index 000000000..16286bdab
--- /dev/null
+++ b/bot/cogs/moderation/incidents.py
@@ -0,0 +1,334 @@
+import asyncio
+import logging
+import typing as t
+from enum import Enum
+
+import discord
+from discord.ext.commands import Cog
+
+from bot.bot import Bot
+from bot.constants import Channels, Emojis, Roles, Webhooks
+
+log = logging.getLogger(__name__)
+
+
+class Signal(Enum):
+ """
+ Recognized incident status signals.
+
+ This binds emoji to actions. The bot will only react to emoji linked here.
+ All other signals are seen as invalid.
+ """
+
+ ACTIONED = Emojis.incident_actioned
+ NOT_ACTIONED = Emojis.incident_unactioned
+ INVESTIGATING = Emojis.incident_investigating
+
+
+# Reactions from roles not listed here, or using emoji not listed here, will be removed
+ALLOWED_ROLES: t.Set[int] = {Roles.moderators, Roles.admins, Roles.owners}
+ALLOWED_EMOJI: t.Set[str] = {signal.value for signal in Signal}
+
+
+def is_incident(message: discord.Message) -> bool:
+ """True if `message` qualifies as an incident, False otherwise."""
+ conditions = (
+ message.channel.id == Channels.incidents, # Message sent in #incidents
+ not message.author.bot, # Not by a bot
+ not message.content.startswith("#"), # Doesn't start with a hash
+ not message.pinned, # And isn't header
+ )
+ return all(conditions)
+
+
+def own_reactions(message: discord.Message) -> t.Set[str]:
+ """Get the set of reactions placed on `message` by the bot itself."""
+ return {str(reaction.emoji) for reaction in message.reactions if reaction.me}
+
+
+def has_signals(message: discord.Message) -> bool:
+ """True if `message` already has all `Signal` reactions, False otherwise."""
+ missing_signals = ALLOWED_EMOJI - own_reactions(message) # In `ALLOWED_EMOJI` but not in `own_reactions(message)`
+ return not missing_signals
+
+
+async def add_signals(incident: discord.Message) -> None:
+ """
+ Add `Signal` member emoji to `incident` as reactions.
+
+ If the emoji has already been placed on `incident` by the bot, it will be skipped.
+ """
+ existing_reacts = own_reactions(incident)
+
+ for signal_emoji in Signal:
+
+ # This will not raise, but it is a superfluous API call that can be avoided
+ if signal_emoji.value in existing_reacts:
+ log.debug(f"Skipping emoji as it's already been placed: {signal_emoji}")
+
+ else:
+ log.debug(f"Adding reaction: {signal_emoji}")
+ await incident.add_reaction(signal_emoji.value)
+
+
+class Incidents(Cog):
+ """
+ Automation for the #incidents channel.
+
+ This cog does not provide a command API, it only reacts to the following events.
+
+ On start-up:
+ * Crawl #incidents and add missing `Signal` emoji where appropriate
+ * This is to retro-actively add the available options for messages which
+ were sent while the bot wasn't listening
+ * Pinned messages and message starting with # do not qualify as incidents
+ * See: `crawl_incidents`
+
+ On message:
+ * Add `Signal` member emoji if message qualifies as an incident
+ * Ignore messages starting with #
+ * Use this if verbal communication is necessary
+ * Each such message must be deleted manually once appropriate
+ * See: `on_message`
+
+ On reaction:
+ * Remove reaction if not permitted (`ALLOWED_EMOJI`, `ALLOWED_ROLES`)
+ * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to
+ relay the incident message to #incidents-archive
+ * If relay successful, delete original message
+ * See: `on_raw_reaction_add`
+
+ Please refer to function docstrings for implementation details.
+ """
+
+ def __init__(self, bot: Bot) -> None:
+ """Prepare `event_lock` and schedule `crawl_task` on start-up."""
+ self.bot = bot
+
+ self.event_lock = asyncio.Lock()
+ self.crawl_task = self.bot.loop.create_task(self.crawl_incidents())
+
+ async def crawl_incidents(self) -> None:
+ """
+ Crawl #incidents and add missing emoji where necessary.
+
+ This is to catch-up should an incident be reported while the bot wasn't listening.
+ After adding each reaction, we take a short break to avoid drowning in ratelimits.
+
+ Once this task is scheduled, listeners that change messages should await it.
+ The crawl assumes that the channel history doesn't change as we go over it.
+ """
+ await self.bot.wait_until_guild_available()
+ incidents: discord.TextChannel = self.bot.get_channel(Channels.incidents)
+
+ # Limit the query at 50 as in practice, there should never be this many messages,
+ # and if there are, something has likely gone very wrong
+ limit = 50
+
+ # Seconds to sleep after adding reactions to a message
+ sleep = 2
+
+ log.debug(f"Crawling messages in #incidents: {limit=}, {sleep=}")
+ async for message in incidents.history(limit=limit):
+
+ if not is_incident(message):
+ log.debug("Skipping message: not an incident")
+ continue
+
+ if has_signals(message):
+ log.debug("Skipping message: already has all signals")
+ continue
+
+ await add_signals(message)
+ await asyncio.sleep(sleep)
+
+ log.debug("Crawl task finished!")
+
+ async def archive(self, incident: discord.Message, outcome: Signal) -> bool:
+ """
+ Relay `incident` to the #incidents-archive channel.
+
+ The following pieces of information are relayed:
+ * Incident message content (clean, pingless)
+ * Incident author name (as webhook author)
+ * Incident author avatar (as webhook avatar)
+ * Resolution signal (`outcome`)
+
+ Return True if the relay finishes successfully. If anything goes wrong, meaning
+ not all information was relayed, return False. This signals that the original
+ message is not safe to be deleted, as we will lose some information.
+ """
+ log.debug(f"Archiving incident: {incident.id} with outcome: {outcome}")
+ try:
+ # First we try to grab the webhook
+ webhook: discord.Webhook = await self.bot.fetch_webhook(Webhooks.incidents_archive)
+
+ # Now relay the incident
+ message: discord.Message = await webhook.send(
+ content=incident.clean_content, # Clean content will prevent mentions from pinging
+ username=incident.author.name,
+ avatar_url=incident.author.avatar_url,
+ wait=True, # This makes the method return the sent Message object
+ )
+
+ # Finally add the `outcome` emoji
+ await message.add_reaction(outcome.value)
+
+ except Exception as exc:
+ log.exception("Failed to archive incident to #incidents-archive", exc_info=exc)
+ return False
+
+ else:
+ log.debug("Message archived successfully!")
+ return True
+
+ def make_confirmation_task(self, incident: discord.Message, timeout: int = 5) -> asyncio.Task:
+ """
+ Create a task to wait `timeout` seconds for `incident` to be deleted.
+
+ If `timeout` passes, this will raise `asyncio.TimeoutError`, signaling that we haven't
+ been able to confirm that the message was deleted.
+ """
+ log.debug(f"Confirmation task will wait {timeout=} seconds for {incident.id=} to be deleted")
+
+ def check(payload: discord.RawReactionActionEvent) -> bool:
+ return payload.message_id == incident.id
+
+ coroutine = self.bot.wait_for(event="raw_message_delete", check=check, timeout=timeout)
+ return self.bot.loop.create_task(coroutine)
+
+ async def process_event(self, reaction: str, incident: discord.Message, member: discord.Member) -> None:
+ """
+ Process a `reaction_add` event in #incidents.
+
+ First, we check that the reaction is a recognized `Signal` member, and that it was sent by
+ a permitted user (at least one role in `ALLOWED_ROLES`). If not, the reaction is removed.
+
+ If the reaction was either `Signal.ACTIONED` or `Signal.NOT_ACTIONED`, we attempt to relay
+ the report to #incidents-archive. If successful, the original message is deleted.
+
+ We do not release `event_lock` until we receive the corresponding `message_delete` event.
+ This ensures that if there is a racing event awaiting the lock, it will fail to find the
+ message, and will abort. There is a `timeout` to ensure that this doesn't hold the lock
+ forever should something go wrong.
+ """
+ members_roles: t.Set[int] = {role.id for role in member.roles}
+ if not members_roles & ALLOWED_ROLES: # Intersection is truthy on at least 1 common element
+ log.debug(f"Removing invalid reaction: user {member} is not permitted to send signals")
+ await incident.remove_reaction(reaction, member)
+ return
+
+ if reaction not in ALLOWED_EMOJI:
+ log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal")
+ await incident.remove_reaction(reaction, member)
+ return
+
+ # If we reach this point, we know that `emoji` is a `Signal` member
+ signal = Signal(reaction)
+ log.debug(f"Received signal: {signal}")
+
+ if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED):
+ log.debug("Reaction was valid, but no action is currently defined for it")
+ return
+
+ relay_successful = await self.archive(incident, signal)
+ if not relay_successful:
+ log.debug("Original message will not be deleted as we failed to relay it to the archive")
+ return
+
+ timeout = 5 # Seconds
+ confirmation_task = self.make_confirmation_task(incident, timeout)
+
+ log.debug("Deleting original message")
+ await incident.delete()
+
+ log.debug(f"Awaiting deletion confirmation: {timeout=} seconds")
+ try:
+ await confirmation_task
+ except asyncio.TimeoutError:
+ log.warning(f"Did not receive incident deletion confirmation within {timeout} seconds!")
+ else:
+ log.debug("Deletion was confirmed")
+
+ async def resolve_message(self, message_id: int) -> t.Optional[discord.Message]:
+ """
+ Get `discord.Message` for `message_id` from cache, or API.
+
+ We first look into the local cache to see if the message is present.
+
+ If not, we try to fetch the message from the API. This is necessary for messages
+ which were sent before the bot's current session.
+
+ In an edge-case, it is also possible that the message was already deleted, and
+ the API will respond with a 404. In such a case, None will be returned.
+ This signals that the event for `message_id` should be ignored.
+ """
+ await self.bot.wait_until_guild_available() # First make sure that the cache is ready
+ log.debug(f"Resolving message for: {message_id=}")
+ message: discord.Message = self.bot._connection._get_message(message_id) # noqa: Private attribute
+
+ if message is not None:
+ log.debug("Message was found in cache")
+ return message
+
+ log.debug("Message not found, attempting to fetch")
+ try:
+ message = await self.bot.get_channel(Channels.incidents).fetch_message(message_id)
+ except discord.NotFound:
+ log.debug("Message doesn't exist, it was likely already relayed")
+ except Exception as exc:
+ log.exception("Failed to fetch message!", exc_info=exc)
+ else:
+ log.debug("Message fetched successfully!")
+ return message
+
+ @Cog.listener()
+ async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None:
+ """
+ Pre-process `payload` and pass it to `process_event` if appropriate.
+
+ We abort instantly if `payload` doesn't relate to a message sent in #incidents,
+ or if it was sent by a bot.
+
+ If `payload` relates to a message in #incidents, we first ensure that `crawl_task` has
+ finished, to make sure we don't mutate channel state as we're crawling it.
+
+ Next, we acquire `event_lock` - to prevent racing, events are processed one at a time.
+
+ Once we have the lock, the `discord.Message` object for this event must be resolved.
+ If the lock was previously held by an event which successfully relayed the incident,
+ this will fail and we abort the current event.
+
+ Finally, with both the lock and the `discord.Message` instance in our hands, we delegate
+ to `process_event` to handle the event.
+
+ The justification for using a raw listener is the need to receive events for messages
+ which were not cached in the current session. As a result, a certain amount of
+ complexity is introduced, but at the moment this doesn't appear to be avoidable.
+ """
+ if payload.channel_id != Channels.incidents or payload.member.bot:
+ return
+
+ log.debug(f"Received reaction add event in #incidents, waiting for crawler: {self.crawl_task.done()=}")
+ await self.crawl_task
+
+ log.debug(f"Acquiring event lock: {self.event_lock.locked()=}")
+ async with self.event_lock:
+ message = await self.resolve_message(payload.message_id)
+
+ if message is None:
+ log.debug("Listener will abort as related message does not exist!")
+ return
+
+ if not is_incident(message):
+ log.debug("Ignoring event for a non-incident message")
+ return
+
+ await self.process_event(str(payload.emoji), message, payload.member)
+ log.debug("Releasing event lock")
+
+ @Cog.listener()
+ async def on_message(self, message: discord.Message) -> None:
+ """Pass `message` to `add_signals` if and only if it satisfies `is_incident`."""
+ if is_incident(message):
+ await add_signals(message)
diff --git a/bot/constants.py b/bot/constants.py
index a1b392c82..b3ef1660f 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -272,6 +272,10 @@ class Emojis(metaclass=YAMLGetter):
status_idle: str
status_dnd: str
+ incident_actioned: str
+ incident_unactioned: str
+ incident_investigating: str
+
failmail: str
trashcan: str
@@ -399,6 +403,7 @@ class Channels(metaclass=YAMLGetter):
helpers: int
how_to_get_help: int
incidents: int
+ incidents_archive: int
message_log: int
meta: int
mod_alerts: int
@@ -427,6 +432,7 @@ class Webhooks(metaclass=YAMLGetter):
reddit: int
duck_pond: int
dev_log: int
+ incidents_archive: int
class Roles(metaclass=YAMLGetter):
diff --git a/config-default.yml b/config-default.yml
index 64c4e715b..4c0196dc5 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -38,6 +38,10 @@ style:
status_dnd: "<:status_dnd:470326272082313216>"
status_offline: "<:status_offline:470326266537705472>"
+ incident_actioned: "<:incident_actioned:719645530128646266>"
+ incident_unactioned: "<:incident_unactioned:719645583245180960>"
+ incident_investigating: "<:incident_investigating:719645658671480924>"
+
failmail: "<:failmail:633660039931887616>"
trashcan: "<:trashcan:637136429717389331>"
@@ -173,6 +177,7 @@ guild:
organisation: &ORGANISATION 551789653284356126
staff_lounge: &STAFF_LOUNGE 464905259261755392
incidents: 714214212200562749
+ incidents_archive: 720668923636351037
# Voice
admins_voice: &ADMINS_VOICE 500734494840717332
@@ -251,7 +256,7 @@ guild:
duck_pond: 637821475327311927
dev_log: 680501655111729222
python_news: &PYNEWS_WEBHOOK 704381182279942324
-
+ incidents_archive: 720671599790915702
filter:
diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py
new file mode 100644
index 000000000..6158d5d20
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_incidents.py
@@ -0,0 +1,452 @@
+import asyncio
+import enum
+import logging
+import unittest
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+import aiohttp
+import discord
+
+from bot.cogs.moderation import Incidents, incidents
+from tests.helpers import (
+ MockAsyncWebhook,
+ MockBot,
+ MockMember,
+ MockMessage,
+ MockReaction,
+ MockRole,
+ MockTextChannel,
+ MockUser,
+)
+
+
+class MockSignal(enum.Enum):
+ A = "A"
+ B = "B"
+
+
+mock_404 = discord.NotFound(
+ response=MagicMock(aiohttp.ClientResponse), # Mock the erroneous response
+ message="Not found",
+)
+
+
+@patch("bot.constants.Channels.incidents", 123)
+class TestIsIncident(unittest.TestCase):
+ """
+ Collection of tests for the `is_incident` helper function.
+
+ In `setUp`, we will create a mock message which should qualify as an incident. Each
+ test case will then mutate this instance to make it **not** qualify, in various ways.
+
+ Notice that we patch the #incidents channel id globally for this class.
+ """
+
+ def setUp(self) -> None:
+ """Prepare a mock message which should qualify as an incident."""
+ self.incident = MockMessage(
+ channel=MockTextChannel(id=123),
+ content="this is an incident",
+ author=MockUser(bot=False),
+ pinned=False,
+ )
+
+ def test_is_incident_true(self):
+ """Message qualifies as an incident if unchanged."""
+ self.assertTrue(incidents.is_incident(self.incident))
+
+ def check_false(self):
+ """Assert that `self.incident` does **not** qualify as an incident."""
+ self.assertFalse(incidents.is_incident(self.incident))
+
+ def test_is_incident_false_channel(self):
+ """Message doesn't qualify if sent outside of #incidents."""
+ self.incident.channel = MockTextChannel(id=456)
+ self.check_false()
+
+ def test_is_incident_false_content(self):
+ """Message doesn't qualify if content begins with hash symbol."""
+ self.incident.content = "# this is a comment message"
+ self.check_false()
+
+ def test_is_incident_false_author(self):
+ """Message doesn't qualify if author is a bot."""
+ self.incident.author = MockUser(bot=True)
+ self.check_false()
+
+ def test_is_incident_false_pinned(self):
+ """Message doesn't qualify if it is pinned."""
+ self.incident.pinned = True
+ self.check_false()
+
+
+class TestOwnReactions(unittest.TestCase):
+ """Assertions for the `own_reactions` function."""
+
+ def test_own_reactions(self):
+ """Only bot's own emoji are extracted from the input incident."""
+ reactions = (
+ MockReaction(emoji="A", me=True),
+ MockReaction(emoji="B", me=True),
+ MockReaction(emoji="C", me=False),
+ )
+ message = MockMessage(reactions=reactions)
+ self.assertSetEqual(incidents.own_reactions(message), {"A", "B"})
+
+
+@patch("bot.cogs.moderation.incidents.ALLOWED_EMOJI", {"A", "B"})
+class TestHasSignals(unittest.TestCase):
+ """
+ Assertions for the `has_signals` function.
+
+ We patch `ALLOWED_EMOJI` globally. Each test function then patches `own_reactions`
+ as appropriate.
+ """
+
+ def test_has_signals_true(self):
+ """True when `own_reactions` returns all emoji in `ALLOWED_EMOJI`."""
+ message = MockMessage()
+ own_reactions = MagicMock(return_value={"A", "B"})
+
+ with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions):
+ self.assertTrue(incidents.has_signals(message))
+
+ def test_has_signals_false(self):
+ """False when `own_reactions` does not return all emoji in `ALLOWED_EMOJI`."""
+ message = MockMessage()
+ own_reactions = MagicMock(return_value={"A", "C"})
+
+ with patch("bot.cogs.moderation.incidents.own_reactions", own_reactions):
+ self.assertFalse(incidents.has_signals(message))
+
+
+@patch("bot.cogs.moderation.incidents.Signal", MockSignal)
+class TestAddSignals(unittest.IsolatedAsyncioTestCase):
+ """
+ Assertions for the `add_signals` coroutine.
+
+ These are all fairly similar and could go into a single test function, but I found the
+ patching & sub-testing fairly awkward in that case and decided to split them up
+ to avoid unnecessary syntax noise.
+ """
+
+ def setUp(self):
+ """Prepare a mock incident message for tests to use."""
+ self.incident = MockMessage()
+
+ @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value=set()))
+ async def test_add_signals_missing(self):
+ """All emoji are added when none are present."""
+ await incidents.add_signals(self.incident)
+ self.incident.add_reaction.assert_has_calls([call("A"), call("B")])
+
+ @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A"}))
+ async def test_add_signals_partial(self):
+ """Only missing emoji are added when some are present."""
+ await incidents.add_signals(self.incident)
+ self.incident.add_reaction.assert_has_calls([call("B")])
+
+ @patch("bot.cogs.moderation.incidents.own_reactions", MagicMock(return_value={"A", "B"}))
+ async def test_add_signals_present(self):
+ """No emoji are added when all are present."""
+ await incidents.add_signals(self.incident)
+ self.incident.add_reaction.assert_not_called()
+
+
+class TestIncidents(unittest.IsolatedAsyncioTestCase):
+ """
+ Tests for bound methods of the `Incidents` cog.
+
+ Use this as a base class for `Incidents` tests - it will prepare a fresh instance
+ for each test function, but not make any assertions on its own. Tests can mutate
+ the instance as they wish.
+ """
+
+ def setUp(self):
+ """
+ Prepare a fresh `Incidents` instance for each test.
+
+ Note that this will not schedule `crawl_incidents` in the background, as everything
+ is being mocked. The `crawl_task` attribute will end up being None.
+ """
+ self.cog_instance = Incidents(MockBot())
+
+
+class TestArchive(TestIncidents):
+ """Tests for the `Incidents.archive` coroutine."""
+
+ async def test_archive_webhook_not_found(self):
+ """
+ Method recovers and returns False when the webhook is not found.
+
+ Implicitly, this also tests that the error is handled internally and doesn't
+ propagate out of the method, which is just as important.
+ """
+ self.cog_instance.bot.fetch_webhook = AsyncMock(side_effect=mock_404)
+ self.assertFalse(await self.cog_instance.archive(incident=MockMessage(), outcome=MagicMock()))
+
+ async def test_archive_relays_incident(self):
+ """
+ If webhook is found, method relays `incident` properly.
+
+ This test will assert the following:
+ * The fetched webhook's `send` method is fed the correct arguments
+ * The message returned by `send` will have `outcome` reaction added
+ * Finally, the `archive` method returns True
+
+ Assertions are made specifically in this order.
+ """
+ webhook_message = MockMessage() # The message that will be returned by the webhook's `send` method
+ webhook = MockAsyncWebhook(send=AsyncMock(return_value=webhook_message))
+
+ self.cog_instance.bot.fetch_webhook = AsyncMock(return_value=webhook) # Patch in our webhook
+
+ # Now we'll pas our own `incident` to `archive` and capture the return value
+ incident = MockMessage(
+ clean_content="pingless message",
+ content="pingful message",
+ author=MockUser(name="author_name", avatar_url="author_avatar"),
+ id=123,
+ )
+ archive_return = await self.cog_instance.archive(incident, outcome=MagicMock(value="A"))
+
+ # Check that the webhook was dispatched correctly
+ webhook.send.assert_called_once_with(
+ content="pingless message",
+ username="author_name",
+ avatar_url="author_avatar",
+ wait=True,
+ )
+
+ # Now check that the correct emoji was added to the relayed message
+ webhook_message.add_reaction.assert_called_once_with("A")
+
+ # Finally check that the method returned True
+ self.assertTrue(archive_return)
+
+
+class TestMakeConfirmationTask(TestIncidents):
+ """
+ Tests for the `Incidents.make_confirmation_task` method.
+
+ Writing tests for this method is difficult, as it mostly just delegates the provided
+ information elsewhere. There is very little internal logic. Whether our approach
+ works conceptually is difficult to prove using unit tests.
+ """
+
+ def test_make_confirmation_task_check(self):
+ """
+ The internal check will recognize the passed incident.
+
+ This is a little tricky - we first pass a message with a specific `id` in, and then
+ retrieve the built check from the `call_args` of the `wait_for` method. This relies
+ on the check being passed as a kwarg.
+
+ Once the check is retrieved, we assert that it gives True for our incident's `id`,
+ and False for any other.
+
+ If this function begins to fail, first check that `created_check` is being retrieved
+ correctly. It should be the function that is built locally in the tested method.
+ """
+ self.cog_instance.make_confirmation_task(MockMessage(id=123))
+
+ self.cog_instance.bot.wait_for.assert_called_once()
+ created_check = self.cog_instance.bot.wait_for.call_args.kwargs["check"]
+
+ # The `message_id` matches the `id` of our incident
+ self.assertTrue(created_check(payload=MagicMock(message_id=123)))
+
+ # This `message_id` does not match
+ self.assertFalse(created_check(payload=MagicMock(message_id=0)))
+
+
+@patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2})
+@patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", AsyncMock()) # Generic awaitable
+class TestProcessEvent(TestIncidents):
+ """Tests for the `Incidents.process_event` coroutine."""
+
+ @patch("bot.cogs.moderation.incidents.ALLOWED_ROLES", {1, 2})
+ async def test_process_event_bad_role(self):
+ """The reaction is removed when the author lacks all allowed roles."""
+ incident = MockMessage()
+ member = MockMember(roles=[MockRole(id=0)]) # Must have role 1 or 2
+
+ await self.cog_instance.process_event("reaction", incident, member)
+ incident.remove_reaction.assert_called_once_with("reaction", member)
+
+ async def test_process_event_bad_emoji(self):
+ """
+ The reaction is removed when an invalid emoji is used.
+
+ This requires that we pass in a `member` with valid roles, as we need the role check
+ to succeed.
+ """
+ incident = MockMessage()
+ member = MockMember(roles=[MockRole(id=1)]) # Member has allowed role
+
+ await self.cog_instance.process_event("invalid_signal", incident, member)
+ incident.remove_reaction.assert_called_once_with("invalid_signal", member)
+
+ async def test_process_event_no_archive_on_investigating(self):
+ """Message is not archived on `Signal.INVESTIGATING`."""
+ with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock()) as mocked_archive:
+ await self.cog_instance.process_event(
+ reaction=incidents.Signal.INVESTIGATING.value,
+ incident=MockMessage(),
+ member=MockMember(roles=[MockRole(id=1)]),
+ )
+
+ mocked_archive.assert_not_called()
+
+ async def test_process_event_no_delete_if_archive_fails(self):
+ """
+ Original message is not deleted when `Incidents.archive` returns False.
+
+ This is the way of signaling that the relay failed, and we should not remove the original,
+ as that would result in losing the incident record.
+ """
+ incident = MockMessage()
+
+ with patch("bot.cogs.moderation.incidents.Incidents.archive", AsyncMock(return_value=False)):
+ await self.cog_instance.process_event(
+ reaction=incidents.Signal.ACTIONED.value,
+ incident=incident,
+ member=MockMember(roles=[MockRole(id=1)])
+ )
+
+ incident.delete.assert_not_called()
+
+ async def test_process_event_confirmation_task_is_awaited(self):
+ """Task given by `Incidents.make_confirmation_task` is awaited before method exits."""
+ mock_task = AsyncMock()
+
+ with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task):
+ await self.cog_instance.process_event(
+ reaction=incidents.Signal.ACTIONED.value,
+ incident=MockMessage(),
+ member=MockMember(roles=[MockRole(id=1)])
+ )
+
+ mock_task.assert_awaited()
+
+ async def test_process_event_confirmation_task_timeout_is_handled(self):
+ """
+ Confirmation task `asyncio.TimeoutError` is handled gracefully.
+
+ We have `make_confirmation_task` return a mock with a side effect, and then catch the
+ exception should it propagate out of `process_event`. This is so that we can then manually
+ fail the test with a more informative message than just the plain traceback.
+ """
+ mock_task = AsyncMock(side_effect=asyncio.TimeoutError())
+
+ try:
+ with patch("bot.cogs.moderation.incidents.Incidents.make_confirmation_task", mock_task):
+ await self.cog_instance.process_event(
+ reaction=incidents.Signal.ACTIONED.value,
+ incident=MockMessage(),
+ member=MockMember(roles=[MockRole(id=1)])
+ )
+ except asyncio.TimeoutError:
+ self.fail("TimeoutError was not handled gracefully, and propagated out of `process_event`!")
+
+
+class TestResolveMessage(TestIncidents):
+ """Tests for the `Incidents.resolve_message` coroutine."""
+
+ async def test_resolve_message_pass_message_id(self):
+ """Method will call `_get_message` with the passed `message_id`."""
+ await self.cog_instance.resolve_message(123)
+ self.cog_instance.bot._connection._get_message.assert_called_once_with(123)
+
+ async def test_resolve_message_in_cache(self):
+ """
+ No API call is made if the queried message exists in the cache.
+
+ We mock the `_get_message` return value regardless of input. Whether it finds the message
+ internally is considered d.py's responsibility, not ours.
+ """
+ cached_message = MockMessage(id=123)
+ self.cog_instance.bot._connection._get_message = MagicMock(return_value=cached_message)
+
+ return_value = await self.cog_instance.resolve_message(123)
+
+ self.assertIs(return_value, cached_message)
+ self.cog_instance.bot.get_channel.assert_not_called() # The `fetch_message` line was never hit
+
+ async def test_resolve_message_not_in_cache(self):
+ """
+ The message is retrieved from the API if it isn't cached.
+
+ This is desired behaviour for messages which exist, but were sent before the bot's
+ current session.
+ """
+ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None
+
+ # API returns our message
+ uncached_message = MockMessage()
+ fetch_message = AsyncMock(return_value=uncached_message)
+ self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message))
+
+ retrieved_message = await self.cog_instance.resolve_message(123)
+ self.assertIs(retrieved_message, uncached_message)
+
+ async def test_resolve_message_doesnt_exist(self):
+ """
+ If the API returns a 404, the function handles it gracefully and returns None.
+
+ This is an edge-case happening with racing events - event A will relay the message
+ to the archive and delete the original. Once event B acquires the `event_lock`,
+ it will not find the message in the cache, and will ask the API.
+ """
+ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None
+
+ fetch_message = AsyncMock(side_effect=mock_404)
+ self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message))
+
+ self.assertIsNone(await self.cog_instance.resolve_message(123))
+
+ async def test_resolve_message_fetch_fails(self):
+ """
+ Non-404 errors are handled, logged & None is returned.
+
+ In contrast with a 404, this should make an error-level log. We assert that at least
+ one such log was made - we do not make any assertions about the log's message.
+ """
+ self.cog_instance.bot._connection._get_message = MagicMock(return_value=None) # Cache returns None
+
+ arbitrary_error = discord.HTTPException(
+ response=MagicMock(aiohttp.ClientResponse),
+ message="Arbitrary error",
+ )
+ fetch_message = AsyncMock(side_effect=arbitrary_error)
+ self.cog_instance.bot.get_channel = MagicMock(return_value=MockTextChannel(fetch_message=fetch_message))
+
+ with self.assertLogs(logger=incidents.log, level=logging.ERROR):
+ self.assertIsNone(await self.cog_instance.resolve_message(123))
+
+
+class TestOnMessage(TestIncidents):
+ """
+ Tests for the `Incidents.on_message` listener.
+
+ Notice the decorators mocking the `is_incident` return value. The `is_incidents`
+ function is tested in `TestIsIncident` - here we do not worry about it.
+ """
+
+ @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=True))
+ async def test_on_message_incident(self):
+ """Messages qualifying as incidents are passed to `add_signals`."""
+ incident = MockMessage()
+
+ with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals:
+ await self.cog_instance.on_message(incident)
+
+ mock_add_signals.assert_called_once_with(incident)
+
+ @patch("bot.cogs.moderation.incidents.is_incident", MagicMock(return_value=False))
+ async def test_on_message_non_incident(self):
+ """Messages not qualifying as incidents are ignored."""
+ with patch("bot.cogs.moderation.incidents.add_signals", AsyncMock()) as mock_add_signals:
+ await self.cog_instance.on_message(MockMessage())
+
+ mock_add_signals.assert_not_called()