diff options
| author | 2020-06-12 18:27:33 -0700 | |
|---|---|---|
| committer | 2020-06-12 18:27:33 -0700 | |
| commit | 87028ab4bcbf8cda866bd7701c8f2559f556147f (patch) | |
| tree | 154f1c96bc8ea63fc9383ae9678244ae31e61ebd | |
| parent | Merge pull request #997 from python-discord/bug/frontend/996/charinfo-md-escape (diff) | |
| parent | Use class instead of NamedTuple (diff) | |
Merge pull request #978 from ItsDrike/unsilence-scheduler
Use Scheduler instead of asyncio.sleep on silence cog
| -rw-r--r-- | bot/cogs/moderation/silence.py | 37 | ||||
| -rw-r--r-- | tests/bot/cogs/moderation/test_silence.py | 18 | 
2 files changed, 45 insertions, 10 deletions
| diff --git a/bot/cogs/moderation/silence.py b/bot/cogs/moderation/silence.py index 25febfa51..c8ab6443b 100644 --- a/bot/cogs/moderation/silence.py +++ b/bot/cogs/moderation/silence.py @@ -1,7 +1,7 @@  import asyncio  import logging  from contextlib import suppress -from typing import Optional +from typing import NamedTuple, Optional  from discord import TextChannel  from discord.ext import commands, tasks @@ -11,10 +11,18 @@ from bot.bot import Bot  from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles  from bot.converters import HushDurationConverter  from bot.utils.checks import with_role_check +from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) +class TaskData(NamedTuple): +    """Data for a scheduled task.""" + +    delay: int +    ctx: Context + +  class SilenceNotifier(tasks.Loop):      """Loop notifier for posting notices to `alert_channel` containing added channels.""" @@ -53,15 +61,25 @@ class SilenceNotifier(tasks.Loop):              await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") -class Silence(commands.Cog): +class Silence(Scheduler, commands.Cog):      """Commands for stopping channel messages for `verified` role in a channel."""      def __init__(self, bot: Bot): +        super().__init__()          self.bot = bot          self.muted_channels = set()          self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars())          self._get_instance_vars_event = asyncio.Event() +    async def _scheduled_task(self, task: TaskData) -> None: +        """Calls `self.unsilence` on expired silenced channel to unsilence it.""" +        await asyncio.sleep(task.delay) +        log.info("Unsilencing channel after set delay.") + +        # Because `self.unsilence` explicitly cancels this scheduled task, it is shielded +        # to avoid prematurely cancelling itself +        await asyncio.shield(task.ctx.invoke(self.unsilence)) +      async def _get_instance_vars(self) -> None:          """Get instance variables after they're available to get from the guild."""          await self.bot.wait_until_guild_available() @@ -90,9 +108,13 @@ class Silence(commands.Cog):              return          await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") -        await asyncio.sleep(duration*60) -        log.info("Unsilencing channel after set delay.") -        await ctx.invoke(self.unsilence) + +        task_data = TaskData( +            delay=duration*60, +            ctx=ctx +        ) + +        self.schedule_task(ctx.channel.id, task_data)      @commands.command(aliases=("unhush",))      async def unsilence(self, ctx: Context) -> None: @@ -103,7 +125,9 @@ class Silence(commands.Cog):          """          await self._get_instance_vars_event.wait()          log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") -        if await self._unsilence(ctx.channel): +        if not await self._unsilence(ctx.channel): +            await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") +        else:              await ctx.send(f"{Emojis.check_mark} unsilenced current channel.")      async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: @@ -140,6 +164,7 @@ class Silence(commands.Cog):          if current_overwrite.send_messages is False:              await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None))              log.info(f"Unsilenced channel #{channel} ({channel.id}).") +            self.cancel_task(channel.id)              self.notifier.remove_channel(channel)              self.muted_channels.discard(channel)              return True diff --git a/tests/bot/cogs/moderation/test_silence.py b/tests/bot/cogs/moderation/test_silence.py index 3fd149f04..ab3d0742a 100644 --- a/tests/bot/cogs/moderation/test_silence.py +++ b/tests/bot/cogs/moderation/test_silence.py @@ -127,10 +127,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):              self.ctx.reset_mock()      async def test_unsilence_sent_correct_discord_message(self): -        """Proper reply after a successful unsilence.""" -        with mock.patch.object(self.cog, "_unsilence", return_value=True): -            await self.cog.unsilence.callback(self.cog, self.ctx) -            self.ctx.send.assert_called_once_with(f"{Emojis.check_mark} unsilenced current channel.") +        """Check if proper message was sent when unsilencing channel.""" +        test_cases = ( +            (True, f"{Emojis.check_mark} unsilenced current channel."), +            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        ) +        for _unsilence_patch_return, result_message in test_cases: +            with self.subTest( +                starting_silenced_state=_unsilence_patch_return, +                result_message=result_message +            ): +                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): +                    await self.cog.unsilence.callback(self.cog, self.ctx) +                    self.ctx.send.assert_called_once_with(result_message) +            self.ctx.reset_mock()      async def test_silence_private_for_false(self):          """Permissions are not set and `False` is returned in an already silenced channel.""" | 
