diff options
| -rw-r--r-- | bot/exts/moderation/infraction/_scheduler.py | 27 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/infractions.py | 15 | ||||
| -rw-r--r-- | bot/exts/moderation/infraction/superstarify.py | 11 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 30 | 
4 files changed, 52 insertions, 31 deletions
| diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 655290559..bd7c4d6f2 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -1,6 +1,7 @@  import textwrap  import typing as t  from abc import abstractmethod +from collections.abc import Awaitable, Callable  from gettext import ngettext  import arrow @@ -79,9 +80,14 @@ class InfractionScheduler:      async def reapply_infraction(          self,          infraction: _utils.Infraction, -        apply_coro: t.Optional[t.Awaitable] +        action: t.Optional[Callable[[], Awaitable[None]]]      ) -> None: -        """Reapply an infraction if it's still active or deactivate it if less than 60 sec left.""" +        """ +        Reapply an infraction if it's still active or deactivate it if less than 60 sec left. + +        Note: The `action` provided is an async function rather than a coroutine +        to prevent getting a RuntimeWarning if it is not used (e.g. in mocked tests). +        """          if infraction["expires_at"] is not None:              # Calculate the time remaining, in seconds, for the mute.              expiry = dateutil.parser.isoparse(infraction["expires_at"]) @@ -101,7 +107,7 @@ class InfractionScheduler:          # Allowing mod log since this is a passive action that should be logged.          try: -            await apply_coro +            await action()          except discord.HTTPException as e:              # When user joined and then right after this left again before action completed, this can't apply roles              if e.code == 10007 or e.status == 404: @@ -111,7 +117,7 @@ class InfractionScheduler:              else:                  log.exception(                      f"Got unexpected HTTPException (HTTP {e.status}, Discord code {e.code})" -                    f"when awaiting {infraction['type']} coroutine for {infraction['user']}." +                    f"when running {infraction['type']} action for {infraction['user']}."                  )          else:              log.info(f"Re-applied {infraction['type']} to user {infraction['user']} upon rejoining.") @@ -121,17 +127,20 @@ class InfractionScheduler:          ctx: Context,          infraction: _utils.Infraction,          user: MemberOrUser, -        action_coro: t.Optional[t.Awaitable] = None, +        action: t.Optional[Callable[[], Awaitable[None]]] = None,          user_reason: t.Optional[str] = None,          additional_info: str = "",      ) -> bool:          """          Apply an infraction to the user, log the infraction, and optionally notify the user. -        `action_coro`, if not provided, will result in the infraction not getting scheduled for deletion. +        `action`, if not provided, will result in the infraction not getting scheduled for deletion.          `user_reason`, if provided, will be sent to the user in place of the infraction reason.          `additional_info` will be attached to the text field in the mod-log embed. +        Note: The `action` provided is an async function rather than just a coroutine +        to prevent getting a RuntimeWarning if it is not used (e.g. in mocked tests). +          Returns whether or not the infraction succeeded.          """          infr_type = infraction["type"] @@ -200,10 +209,10 @@ class InfractionScheduler:          purge = infraction.get("purge", "")          # Execute the necessary actions to apply the infraction on Discord. -        if action_coro: -            log.trace(f"Awaiting the infraction #{id_} application action coroutine.") +        if action: +            log.trace(f"Running the infraction #{id_} application action.")              try: -                await action_coro +                await action()                  if expiry:                      # Schedule the expiration of the infraction.                      self.schedule_expiration(infraction) diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index 05cc74a03..fb2ab9579 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -54,8 +54,9 @@ class Infractions(InfractionScheduler, commands.Cog):          if active_mutes:              reason = f"Re-applying active mute: {active_mutes[0]['id']}" -            action = member.add_roles(self._muted_role, reason=reason) +            async def action() -> None: +                await member.add_roles(self._muted_role, reason=reason)              await self.reapply_infraction(active_mutes[0], action)      # region: Permanent infractions @@ -397,7 +398,7 @@ class Infractions(InfractionScheduler, commands.Cog):              log.trace(f"Attempting to kick {user} from voice because they've been muted.")              await user.move_to(None, reason=reason) -        await self.apply_infraction(ctx, infraction, user, action()) +        await self.apply_infraction(ctx, infraction, user, action)      @respect_role_hierarchy(member_arg=2)      async def apply_kick(self, ctx: Context, user: Member, reason: t.Optional[str], **kwargs) -> None: @@ -415,7 +416,9 @@ class Infractions(InfractionScheduler, commands.Cog):          if reason:              reason = textwrap.shorten(reason, width=512, placeholder="...") -        action = user.kick(reason=reason) +        async def action() -> None: +            await user.kick(reason=reason) +          await self.apply_infraction(ctx, infraction, user, action)      @respect_role_hierarchy(member_arg=2) @@ -464,7 +467,9 @@ class Infractions(InfractionScheduler, commands.Cog):          if reason:              reason = textwrap.shorten(reason, width=512, placeholder="...") -        action = ctx.guild.ban(user, reason=reason, delete_message_days=purge_days) +        async def action() -> None: +            await ctx.guild.ban(user, reason=reason, delete_message_days=purge_days) +          await self.apply_infraction(ctx, infraction, user, action)          bb_cog: t.Optional[BigBrother] = self.bot.get_cog("Big Brother") @@ -502,7 +507,7 @@ class Infractions(InfractionScheduler, commands.Cog):              await user.move_to(None, reason="Disconnected from voice to apply voice mute.")              await user.remove_roles(self._voice_verified_role, reason=reason) -        await self.apply_infraction(ctx, infraction, user, action()) +        await self.apply_infraction(ctx, infraction, user, action)      # endregion      # region: Base pardon functions diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index f2aab7a92..6cb2c3354 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -96,11 +96,12 @@ class Superstarify(InfractionScheduler, Cog):          if active_superstarifies:              infraction = active_superstarifies[0] -            action = member.edit( -                nick=self.get_nick(infraction["id"], member.id), -                reason=f"Superstarified member tried to escape the prison: {infraction['id']}" -            ) +            async def action() -> None: +                await member.edit( +                    nick=self.get_nick(infraction["id"], member.id), +                    reason=f"Superstarified member tried to escape the prison: {infraction['id']}" +                )              await self.reapply_infraction(infraction, action)      @command(name="superstarify", aliases=("force_nick", "star", "starify", "superstar")) @@ -175,7 +176,7 @@ class Superstarify(InfractionScheduler, Cog):          ).format          successful = await self.apply_infraction( -            ctx, infraction, member, action(), +            ctx, infraction, member, action,              user_reason=user_message(reason=f'**Additional details:** {reason}\n\n' if reason else ''),              additional_info=nickname_info          ) diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index a18a4d23b..ca9342550 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -35,17 +35,20 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.apply_infraction = AsyncMock()          self.bot.get_cog.return_value = AsyncMock()          self.cog.mod_log.ignore = Mock() -        self.ctx.guild.ban = Mock() +        self.ctx.guild.ban = AsyncMock()          await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) -        self.ctx.guild.ban.assert_called_once_with( +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar", "purge": ""}, self.target, ANY +        ) + +        action = self.cog.apply_infraction.call_args.args[-1] +        await action() +        self.ctx.guild.ban.assert_awaited_once_with(              self.target,              reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."),              delete_message_days=0          ) -        self.cog.apply_infraction.assert_awaited_once_with( -            self.ctx, {"foo": "bar", "purge": ""}, self.target, self.ctx.guild.ban.return_value -        )      @patch("bot.exts.moderation.infraction._utils.post_infraction")      async def test_apply_kick_reason_truncation(self, post_infraction_mock): @@ -54,14 +57,17 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.apply_infraction = AsyncMock()          self.cog.mod_log.ignore = Mock() -        self.target.kick = Mock() +        self.target.kick = AsyncMock()          await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) -        self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))          self.cog.apply_infraction.assert_awaited_once_with( -            self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value +            self.ctx, {"foo": "bar"}, self.target, ANY          ) +        action = self.cog.apply_infraction.call_args.args[-1] +        await action() +        self.target.kick.assert_awaited_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) +  @patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456)  class VoiceMuteTests(unittest.IsolatedAsyncioTestCase): @@ -141,8 +147,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):      async def action_tester(self, action, reason: str) -> None:          """Helper method to test voice mute action.""" -        self.assertTrue(inspect.iscoroutine(action)) -        await action +        self.assertTrue(inspect.iscoroutinefunction(action)) +        await action()          self.user.move_to.assert_called_once_with(None, reason=ANY)          self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason=reason) @@ -195,8 +201,8 @@ class VoiceMuteTests(unittest.IsolatedAsyncioTestCase):          # Test action          action = self.cog.apply_infraction.call_args[0][-1] -        self.assertTrue(inspect.iscoroutine(action)) -        await action +        self.assertTrue(inspect.iscoroutinefunction(action)) +        await action()      async def test_voice_unmute_user_not_found(self):          """Should include info to return dict when user was not found from guild.""" | 
