diff options
| -rw-r--r-- | bot/cogs/moderation/incidents.py | 18 | ||||
| -rw-r--r-- | tests/bot/cogs/moderation/test_incidents.py | 8 |
2 files changed, 15 insertions, 11 deletions
diff --git a/bot/cogs/moderation/incidents.py b/bot/cogs/moderation/incidents.py index 089a5bc9f..41a98bcb7 100644 --- a/bot/cogs/moderation/incidents.py +++ b/bot/cogs/moderation/incidents.py @@ -33,9 +33,11 @@ class Signal(Enum): INVESTIGATING = Emojis.incident_investigating -# Reactions from roles not listed here, or using emoji not listed here, will be removed +# Reactions from roles not listed here will be removed ALLOWED_ROLES: t.Set[int] = {Roles.moderators, Roles.admins, Roles.owners} -ALLOWED_EMOJI: t.Set[str] = {signal.value for signal in Signal} + +# Message must have all of these emoji to pass the `has_signals` check +ALL_SIGNALS: t.Set[str] = {signal.value for signal in Signal} def is_incident(message: discord.Message) -> bool: @@ -56,7 +58,7 @@ def own_reactions(message: discord.Message) -> t.Set[str]: def has_signals(message: discord.Message) -> bool: """True if `message` already has all `Signal` reactions, False otherwise.""" - return ALLOWED_EMOJI.issubset(own_reactions(message)) + return ALL_SIGNALS.issubset(own_reactions(message)) async def add_signals(incident: discord.Message) -> None: @@ -96,7 +98,9 @@ class Incidents(Cog): * See: `on_message` On reaction: - * Remove reaction if not permitted (`ALLOWED_EMOJI`, `ALLOWED_ROLES`) + * Remove reaction if not permitted + * User does not have any of the roles in `ALLOWED_ROLES` + * Used emoji is not a `Signal` member * If `Signal.ACTIONED` or `Signal.NOT_ACTIONED` were chosen, attempt to relay the incident message to #incidents-archive * If relay successful, delete original message @@ -217,13 +221,13 @@ class Incidents(Cog): await incident.remove_reaction(reaction, member) return - if reaction not in ALLOWED_EMOJI: + try: + signal = Signal(reaction) + except ValueError: log.debug(f"Removing invalid reaction: emoji {reaction} is not a valid signal") await incident.remove_reaction(reaction, member) return - # If we reach this point, we know that `emoji` is a `Signal` member - signal = Signal(reaction) log.trace(f"Received signal: {signal}") if signal not in (Signal.ACTIONED, Signal.NOT_ACTIONED): diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 55b15ec9e..862736785 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -131,17 +131,17 @@ class TestOwnReactions(unittest.TestCase): self.assertSetEqual(incidents.own_reactions(message), {"A", "B"}) -@patch("bot.cogs.moderation.incidents.ALLOWED_EMOJI", {"A", "B"}) +@patch("bot.cogs.moderation.incidents.ALL_SIGNALS", {"A", "B"}) class TestHasSignals(unittest.TestCase): """ Assertions for the `has_signals` function. - We patch `ALLOWED_EMOJI` globally. Each test function then patches `own_reactions` + We patch `ALL_SIGNALS` globally. Each test function then patches `own_reactions` as appropriate. """ def test_has_signals_true(self): - """True when `own_reactions` returns all emoji in `ALLOWED_EMOJI`.""" + """True when `own_reactions` returns all emoji in `ALL_SIGNALS`.""" message = MockMessage() own_reactions = MagicMock(return_value={"A", "B"}) @@ -149,7 +149,7 @@ class TestHasSignals(unittest.TestCase): self.assertTrue(incidents.has_signals(message)) def test_has_signals_false(self): - """False when `own_reactions` does not return all emoji in `ALLOWED_EMOJI`.""" + """False when `own_reactions` does not return all emoji in `ALL_SIGNALS`.""" message = MockMessage() own_reactions = MagicMock(return_value={"A", "C"}) |