diff options
-rw-r--r-- | bot/constants.py | 2 | ||||
-rw-r--r-- | bot/exts/moderation/silence.py | 273 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 394 | ||||
-rw-r--r-- | tests/helpers.py | 25 |
4 files changed, 611 insertions, 83 deletions
diff --git a/bot/constants.py b/bot/constants.py index 08ae0d52f..5da67f36a 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -401,6 +401,8 @@ class Channels(metaclass=YAMLGetter): code_help_chat_2: int code_help_voice_1: int code_help_voice_2: int + admins_voice: int + staff_voice: int cooldown: int defcon: int dev_contrib: int diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index e6712b3b6..45c3f5b92 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -2,35 +2,36 @@ import json import logging from contextlib import suppress from datetime import datetime, timedelta, timezone -from operator import attrgetter -from typing import Optional +from typing import Optional, Union from async_rediscache import RedisCache -from discord import TextChannel +from discord import Guild, PermissionOverwrite, TextChannel, VoiceChannel from discord.ext import commands, tasks from discord.ext.commands import Context +from bot import constants from bot.bot import Bot -from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles from bot.converters import HushDurationConverter -from bot.utils.lock import LockedResourceError, lock_arg +from bot.utils.lock import LockedResourceError, lock, lock_arg from bot.utils.scheduling import Scheduler log = logging.getLogger(__name__) LOCK_NAMESPACE = "silence" -MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced." -MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely." -MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)." +MSG_SILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel is already silenced." +MSG_SILENCE_PERMANENT = f"{constants.Emojis.check_mark} silenced current channel indefinitely." +MSG_SILENCE_SUCCESS = f"{constants.Emojis.check_mark} silenced current channel for {{duration}} minute(s)." -MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced." +MSG_UNSILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel was not silenced." MSG_UNSILENCE_MANUAL = ( - f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " + f"{constants.Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " f"set manually or the cache was prematurely cleared. " f"Please edit the overwrites manually to unsilence." ) -MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel." +MSG_UNSILENCE_SUCCESS = f"{constants.Emojis.check_mark} unsilenced current channel." + +TextOrVoiceChannel = Union[TextChannel, VoiceChannel] class SilenceNotifier(tasks.Loop): @@ -41,7 +42,7 @@ class SilenceNotifier(tasks.Loop): self._silenced_channels = {} self._alert_channel = alert_channel - def add_channel(self, channel: TextChannel) -> None: + def add_channel(self, channel: TextOrVoiceChannel) -> None: """Add channel to `_silenced_channels` and start loop if not launched.""" if not self._silenced_channels: self.start() @@ -68,7 +69,9 @@ class SilenceNotifier(tasks.Loop): f"{channel.mention} for {(self._current_loop-start)//60} min" for channel, start in self._silenced_channels.items() ) - await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") + await self._alert_channel.send( + f"<@&{constants.Roles.moderators}> currently silenced channels: {channels_text}" + ) class Silence(commands.Cog): @@ -92,88 +95,234 @@ class Silence(commands.Cog): """Set instance attributes once the guild is available and reschedule unsilences.""" await self.bot.wait_until_guild_available() - guild = self.bot.get_guild(Guild.id) - self._verified_role = guild.get_role(Roles.verified) - self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) - self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log)) + guild = self.bot.get_guild(constants.Guild.id) + + self._verified_msg_role = guild.get_role(constants.Roles.verified) + self._verified_voice_role = guild.get_role(constants.Roles.voice_verified) + self._helper_role = guild.get_role(constants.Roles.helpers) + + self._mod_alerts_channel = self.bot.get_channel(constants.Channels.mod_alerts) + + self.notifier = SilenceNotifier(self.bot.get_channel(constants.Channels.mod_log)) await self._reschedule() + async def _get_related_text_channel(self, channel: VoiceChannel) -> Optional[TextChannel]: + """Returns the text channel related to a voice channel.""" + # TODO: Figure out a dynamic way of doing this + channels = { + "off-topic": constants.Channels.voice_chat, + "code/help 1": constants.Channels.code_help_voice, + "code/help 2": constants.Channels.code_help_voice_2, + "admin": constants.Channels.admins_voice, + "staff": constants.Channels.staff_voice + } + for name in channels.keys(): + if name in channel.name.lower(): + return self.bot.get_channel(channels[name]) + + async def send_message( + self, + message: str, + source_channel: TextChannel, + target_channel: TextOrVoiceChannel, + alert_target: bool = False + ) -> None: + """Helper function to send message confirmation to `source_channel`, and notification to `target_channel`.""" + # Reply to invocation channel + source_reply = message + if source_channel != target_channel: + source_reply = source_reply.replace("current channel", target_channel.mention) + await source_channel.send(source_reply) + + # Reply to target channel + if alert_target: + if isinstance(target_channel, VoiceChannel): + voice_chat = await self._get_related_text_channel(target_channel) + if voice_chat and source_channel != voice_chat: + await voice_chat.send(message.replace("current channel", target_channel.mention)) + elif source_channel != target_channel: + await target_channel.send(message) + + async def _select_lock_channel(*args) -> TextOrVoiceChannel: + """Passes the channel to be silenced to the resource lock.""" + channel = args[0].get("channel") + if channel is not None: + return channel + + else: + return args[0].get("ctx").channel + @commands.command(aliases=("hush",)) - @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True) - async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: + @lock(LOCK_NAMESPACE, _select_lock_channel, raise_error=True) + async def silence( + self, + ctx: Context, + duration: HushDurationConverter = 10, + kick: bool = False, + *, channel: TextOrVoiceChannel = None + ) -> None: """ Silence the current channel for `duration` minutes or `forever`. Duration is capped at 15 minutes, passing forever makes the silence indefinite. Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. + + Passing a voice channel will attempt to move members out of the channel and back to force sync permissions. + If `kick` is True, members will not be added back to the voice channel, and members will be unable to rejoin. """ await self._init_task - - channel_info = f"#{ctx.channel} ({ctx.channel.id})" + if channel is None: + channel = ctx.channel + channel_info = f"#{channel} ({channel.id})" log.debug(f"{ctx.author} is silencing channel {channel_info}.") - if not await self._set_silence_overwrites(ctx.channel): + if not await self._set_silence_overwrites(channel, kick): log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.") - await ctx.send(MSG_SILENCE_FAIL) + await self.send_message(MSG_SILENCE_FAIL, ctx.channel, channel) return - await self._schedule_unsilence(ctx, duration) + if isinstance(channel, VoiceChannel): + if kick: + await self._kick_voice_members(channel) + else: + await self._force_voice_sync(channel) + + await self._schedule_unsilence(ctx, channel, duration) if duration is None: - self.notifier.add_channel(ctx.channel) + self.notifier.add_channel(channel) log.info(f"Silenced {channel_info} indefinitely.") - await ctx.send(MSG_SILENCE_PERMANENT) + await self.send_message(MSG_SILENCE_PERMANENT, ctx.channel, channel, True) + else: log.info(f"Silenced {channel_info} for {duration} minute(s).") - await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration)) + await self.send_message(MSG_SILENCE_SUCCESS.format(duration=duration), ctx.channel, channel, True) @commands.command(aliases=("unhush",)) - async def unsilence(self, ctx: Context) -> None: + async def unsilence(self, ctx: Context, *, channel: TextOrVoiceChannel = None) -> None: """ - Unsilence the current channel. + Unsilence the given channel if given, else the current one. If the channel was silenced indefinitely, notifications for the channel will stop. """ await self._init_task - log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") - await self._unsilence_wrapper(ctx.channel) + if channel is None: + channel = ctx.channel + log.debug(f"Unsilencing channel #{channel} from {ctx.author}'s command.") + await self._unsilence_wrapper(channel, ctx) @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True) - async def _unsilence_wrapper(self, channel: TextChannel) -> None: + async def _unsilence_wrapper(self, channel: TextOrVoiceChannel, ctx: Optional[Context] = None) -> None: """Unsilence `channel` and send a success/failure message.""" + msg_channel = channel + if ctx is not None: + msg_channel = ctx.channel + if not await self._unsilence(channel): - overwrite = channel.overwrites_for(self._verified_role) - if overwrite.send_messages is False or overwrite.add_reactions is False: - await channel.send(MSG_UNSILENCE_MANUAL) + if isinstance(channel, VoiceChannel): + overwrite = channel.overwrites_for(self._verified_voice_role) + manual = overwrite.speak is False else: - await channel.send(MSG_UNSILENCE_FAIL) + overwrite = channel.overwrites_for(self._verified_msg_role) + manual = overwrite.send_messages is False or overwrite.add_reactions is False + + # Send fail message to muted channel or voice chat channel, and invocation channel + if manual: + await self.send_message(MSG_UNSILENCE_MANUAL, msg_channel, channel) + else: + await self.send_message(MSG_UNSILENCE_FAIL, msg_channel, channel) + else: - await channel.send(MSG_UNSILENCE_SUCCESS) + await self.send_message(MSG_UNSILENCE_SUCCESS, msg_channel, channel, True) - async def _set_silence_overwrites(self, channel: TextChannel) -> bool: + async def _set_silence_overwrites(self, channel: TextOrVoiceChannel, kick: bool = False) -> bool: """Set silence permission overwrites for `channel` and return True if successful.""" - overwrite = channel.overwrites_for(self._verified_role) - prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) + # Get the original channel overwrites + if isinstance(channel, TextChannel): + role = self._verified_msg_role + overwrite = channel.overwrites_for(role) + prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) + + else: + role = self._verified_voice_role + overwrite = channel.overwrites_for(role) + prev_overwrites = dict(speak=overwrite.speak) + if kick: + prev_overwrites.update(connect=overwrite.connect) + # Stop if channel was already silenced if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()): return False - overwrite.update(send_messages=False, add_reactions=False) - await channel.set_permissions(self._verified_role, overwrite=overwrite) + # Set new permissions, store + overwrite.update(**dict.fromkeys(prev_overwrites, False)) + await channel.set_permissions(role, overwrite=overwrite) await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites)) return True - async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None: + @staticmethod + async def _get_afk_channel(guild: Guild) -> VoiceChannel: + """Get a guild's AFK channel, or create one if it does not exist.""" + afk_channel = guild.afk_channel + + if afk_channel is None: + overwrites = { + guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) + } + afk_channel = await guild.create_voice_channel("mute-temp", overwrites=overwrites) + log.info(f"Failed to get afk-channel, created temporary channel #{afk_channel} ({afk_channel.id})") + + return afk_channel + + async def _kick_voice_members(self, channel: VoiceChannel) -> None: + """Remove all non-staff members from a voice channel.""" + log.debug(f"Removing all non staff members from #{channel.name} ({channel.id}).") + + for member in channel.members: + if self._helper_role not in member.roles: + await member.move_to(None, reason="Kicking member from voice channel.") + + log.debug("Removed all members.") + + async def _force_voice_sync(self, channel: VoiceChannel) -> None: + """ + Move all non-staff members from `channel` to a temporary channel and back to force toggle role mute. + + Permission modification has to happen before this function. + """ + # Obtain temporary channel + delete_channel = channel.guild.afk_channel is None + afk_channel = await self._get_afk_channel(channel.guild) + + try: + # Move all members to temporary channel and back + for member in channel.members: + # Skip staff + if self._helper_role in member.roles: + continue + + await member.move_to(afk_channel, reason="Muting VC member.") + log.debug(f"Moved {member.name} to afk channel.") + + await member.move_to(channel, reason="Muting VC member.") + log.debug(f"Moved {member.name} to original voice channel.") + + finally: + # Delete VC channel if it was created. + if delete_channel: + await afk_channel.delete(reason="Deleting temp mute channel.") + + async def _schedule_unsilence(self, ctx: Context, channel: TextOrVoiceChannel, duration: Optional[int]) -> None: """Schedule `ctx.channel` to be unsilenced if `duration` is not None.""" if duration is None: - await self.unsilence_timestamps.set(ctx.channel.id, -1) + await self.unsilence_timestamps.set(channel.id, -1) else: - self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) + self.scheduler.schedule_later(duration * 60, channel.id, ctx.invoke(self.unsilence, channel=channel)) unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration) - await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp()) + await self.unsilence_timestamps.set(channel.id, unsilence_time.timestamp()) - async def _unsilence(self, channel: TextChannel) -> bool: + async def _unsilence(self, channel: TextOrVoiceChannel) -> bool: """ Unsilence `channel`. @@ -183,19 +332,34 @@ class Silence(commands.Cog): Return `True` if channel permissions were changed, `False` otherwise. """ + # Get stored overwrites, and return if channel is unsilenced prev_overwrites = await self.previous_overwrites.get(channel.id) if channel.id not in self.scheduler and prev_overwrites is None: log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") return False - overwrite = channel.overwrites_for(self._verified_role) + # Select the role based on channel type, and get current overwrites + if isinstance(channel, TextChannel): + role = self._verified_msg_role + overwrite = channel.overwrites_for(role) + permissions = "`Send Messages` and `Add Reactions`" + else: + role = self._verified_voice_role + overwrite = channel.overwrites_for(role) + permissions = "`Speak` and `Connect`" + + # Check if old overwrites were not stored if prev_overwrites is None: log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.") - overwrite.update(send_messages=None, add_reactions=None) + overwrite.update(send_messages=None, add_reactions=None, speak=None, connect=None) else: overwrite.update(**json.loads(prev_overwrites)) - await channel.set_permissions(self._verified_role, overwrite=overwrite) + # Update Permissions + await channel.set_permissions(role, overwrite=overwrite) + if isinstance(channel, VoiceChannel): + await self._force_voice_sync(channel) + log.info(f"Unsilenced channel #{channel} ({channel.id}).") self.scheduler.cancel(channel.id) @@ -203,11 +367,12 @@ class Silence(commands.Cog): await self.previous_overwrites.delete(channel.id) await self.unsilence_timestamps.delete(channel.id) + # Alert Admin team if old overwrites were not available if prev_overwrites is None: await self._mod_alerts_channel.send( - f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing " - f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` " - f"overwrites for {self._verified_role.mention} are at their desired values." + f"<@&{constants.Roles.admins}> Restored overwrites with default values after unsilencing " + f"{channel.mention}. Please check that the {permissions} " + f"overwrites for {role.mention} are at their desired values." ) return True @@ -247,7 +412,7 @@ class Silence(commands.Cog): # This cannot be static (must have a __func__ attribute). async def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" - return await commands.has_any_role(*MODERATION_ROLES).predicate(ctx) + return await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx) def setup(bot: Bot) -> None: diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 104293d8e..635e017e3 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -2,14 +2,14 @@ import asyncio import unittest from datetime import datetime, timezone from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from async_rediscache import RedisSession from discord import PermissionOverwrite from bot.constants import Channels, Guild, Roles from bot.exts.moderation import silence -from tests.helpers import MockBot, MockContext, MockTextChannel, autospec +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockTextChannel, MockVoiceChannel, autospec redis_session = None redis_loop = asyncio.get_event_loop() @@ -123,7 +123,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): guild.get_role.side_effect = lambda id_: Mock(id=id_) await self.cog._async_init() - self.assertEqual(self.cog._verified_role.id, Roles.verified) + self.assertEqual(self.cog._verified_msg_role.id, Roles.verified) @autospec(silence, "SilenceNotifier", pass_mocks=False) async def test_async_init_got_channels(self): @@ -158,7 +158,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(self.cog._init_task.cancelled()) @autospec("discord.ext.commands", "has_any_role") - @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) + @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" ctx = MockContext() @@ -168,6 +168,157 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): role_check.assert_called_once_with(*(1, 2, 3)) role_check.return_value.predicate.assert_awaited_once_with(ctx) + @mock.patch.object(silence.Silence, "_get_related_text_channel") + async def test_send_message(self, mock_get_related_text_channel): + """Test the send function reports to the correct channels.""" + text_channel_1 = MockTextChannel() + text_channel_2 = MockTextChannel() + + voice_channel = MockVoiceChannel() + voice_channel.name = "General/Offtopic" + voice_channel.mention = f"#{voice_channel.name}" + + mock_get_related_text_channel.return_value = text_channel_2 + + def reset(): + text_channel_1.reset_mock() + text_channel_2.reset_mock() + voice_channel.reset_mock() + + with self.subTest("Basic One Channel Test"): + await self.cog.send_message("Text basic message.", text_channel_1, text_channel_2, False) + text_channel_1.send.assert_called_once_with("Text basic message.") + text_channel_2.send.assert_not_called() + + reset() + with self.subTest("Basic Two Channel Test"): + await self.cog.send_message("Text basic message.", text_channel_1, text_channel_2, True) + text_channel_1.send.assert_called_once_with("Text basic message.") + text_channel_2.send.assert_called_once_with("Text basic message.") + + reset() + with self.subTest("Replacement One Channel Test"): + message = "Current. The following should be replaced: current channel." + await self.cog.send_message(message, text_channel_1, text_channel_2, False) + + text_channel_1.send.assert_called_once_with(message.replace("current channel", text_channel_1.mention)) + text_channel_2.send.assert_not_called() + + reset() + with self.subTest("Replacement Two Channel Test"): + message = "Current. The following should be replaced: current channel." + await self.cog.send_message(message, text_channel_1, text_channel_2, True) + + text_channel_1.send.assert_called_once_with(message.replace("current channel", text_channel_1.mention)) + text_channel_2.send.assert_called_once_with(message) + + reset() + with self.subTest("Text and Voice"): + await self.cog.send_message("This should show up just here", text_channel_1, voice_channel, False) + text_channel_1.send.assert_called_once_with("This should show up just here") + + reset() + with self.subTest("Text and Voice"): + message = "This should show up as current channel" + await self.cog.send_message(message, text_channel_1, voice_channel, True) + text_channel_1.send.assert_called_once_with(message.replace("current channel", voice_channel.mention)) + text_channel_2.send.assert_called_once_with(message.replace("current channel", voice_channel.mention)) + + reset() + with self.subTest("Text and Voice Same Invocation"): + message = "This should show up as current channel" + await self.cog.send_message(message, text_channel_2, voice_channel, True) + text_channel_2.send.assert_called_once_with(message.replace("current channel", voice_channel.mention)) + + async def test_get_related_text_channel(self): + """Tests the helper function that connects voice to text channels.""" + voice_channel = MockVoiceChannel() + + tests = ( + ("Off-Topic/General", Channels.voice_chat), + ("code/help 1", Channels.code_help_voice), + ("Staff", Channels.staff_voice), + ("ADMIN", Channels.admins_voice), + ("not in the channel list", None) + ) + + with mock.patch.object(self.cog.bot, "get_channel", lambda x: x): + for (name, channel_id) in tests: + voice_channel.name = name + voice_channel.id = channel_id + + result_id = await self.cog._get_related_text_channel(voice_channel) + self.assertEqual(result_id, channel_id) + + async def test_force_voice_sync(self): + """Tests the _force_voice_sync helper function.""" + await self.cog._async_init() + + members = [] + for _ in range(10): + members.append(MockMember()) + + afk_channel = MockVoiceChannel() + channel = MockVoiceChannel(guild=MockGuild(afk_channel=afk_channel), members=members) + + await self.cog._force_voice_sync(channel) + for member in members: + self.assertEqual(member.move_to.call_count, 2) + calls = [ + mock.call(afk_channel, reason="Muting VC member."), + mock.call(channel, reason="Muting VC member.") + ] + member.move_to.assert_has_calls(calls, any_order=False) + + async def test_force_voice_sync_staff(self): + """Tests to ensure _force_voice_sync does not kick staff members.""" + await self.cog._async_init() + member = MockMember(roles=[self.cog._helper_role]) + + await self.cog._force_voice_sync(MockVoiceChannel(members=[member])) + member.move_to.assert_not_called() + + async def test_force_voice_sync_no_channel(self): + """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" + await self.cog._async_init() + + channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) + new_channel = MockVoiceChannel(delete=AsyncMock()) + channel.guild.create_voice_channel.return_value = new_channel + + await self.cog._force_voice_sync(channel) + + # Check channel creation + overwrites = { + channel.guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) + } + channel.guild.create_voice_channel.assert_called_once_with("mute-temp", overwrites=overwrites) + + # Check bot deleted channel + new_channel.delete.assert_called_once_with(reason="Deleting temp mute channel.") + + async def test_voice_kick(self): + """Test to ensure kick function can remove all members from a voice channel.""" + await self.cog._async_init() + + members = [] + for _ in range(10): + members.append(MockMember()) + + channel = MockVoiceChannel(members=members) + await self.cog._kick_voice_members(channel) + + for member in members: + member.move_to.assert_called_once_with(None, reason="Kicking member from voice channel.") + + async def test_voice_kick_staff(self): + """Test to ensure voice kick skips staff members.""" + await self.cog._async_init() + member = MockMember(roles=[self.cog._helper_role]) + + await self.cog._kick_voice_members(MockVoiceChannel(members=[member])) + member.move_to.assert_not_called() + @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) class RescheduleTests(unittest.IsolatedAsyncioTestCase): @@ -261,9 +412,13 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): asyncio.run(self.cog._async_init()) # Populate instance attributes. - self.channel = MockTextChannel() - self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) - self.channel.overwrites_for.return_value = self.overwrite + self.text_channel = MockTextChannel() + self.text_overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) + self.text_channel.overwrites_for.return_value = self.text_overwrite + + self.voice_channel = MockVoiceChannel() + self.voice_overwrite = PermissionOverwrite(speak=True) + self.voice_channel.overwrites_for.return_value = self.voice_overwrite async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" @@ -277,7 +432,60 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): with self.subTest(was_silenced=was_silenced, message=message, duration=duration): await self.cog.silence.callback(self.cog, ctx, duration) - ctx.send.assert_called_once_with(message) + ctx.channel.send.assert_called_once_with(message) + + @mock.patch.object(silence.Silence, "_set_silence_overwrites", return_value=True) + @mock.patch.object(silence.Silence, "_force_voice_sync") + async def test_sent_to_correct_channel(self, voice_sync, _): + """Test function sends messages to the correct channels.""" + text_channel = MockTextChannel() + voice_channel = MockVoiceChannel() + ctx = MockContext() + + test_cases = ( + (None, silence.MSG_SILENCE_SUCCESS.format(duration=10)), + (text_channel, silence.MSG_SILENCE_SUCCESS.format(duration=10)), + (voice_channel, silence.MSG_SILENCE_SUCCESS.format(duration=10)), + (ctx.channel, silence.MSG_SILENCE_SUCCESS.format(duration=10)), + ) + + for target, message in test_cases: + with self.subTest(target_channel=target, message=message): + await self.cog.silence.callback(self.cog, ctx, 10, False, channel=target) + if ctx.channel == target or target is None: + ctx.channel.send.assert_called_once_with(message) + + else: + ctx.channel.send.assert_called_once_with(message.replace("current channel", target.mention)) + if isinstance(target, MockTextChannel): + target.send.assert_called_once_with(message) + else: + voice_sync.assert_called_once_with(target) + + ctx.channel.send.reset_mock() + if target is not None and isinstance(target, MockTextChannel): + target.send.reset_mock() + + @mock.patch.object(silence.Silence, "_kick_voice_members") + @mock.patch.object(silence.Silence, "_force_voice_sync") + async def test_sync_or_kick_called(self, sync, kick): + """Tests if silence command calls kick or sync on voice channels when appropriate.""" + channel = MockVoiceChannel() + ctx = MockContext() + + with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): + with self.subTest("Test calls kick"): + await self.cog.silence.callback(self.cog, ctx, 10, kick=True, channel=channel) + kick.assert_called_once_with(channel) + sync.assert_not_called() + + kick.reset_mock() + sync.reset_mock() + + with self.subTest("Test calls sync"): + await self.cog.silence.callback(self.cog, ctx, 10, kick=False, channel=channel) + sync.assert_called_once_with(channel) + kick.assert_not_called() async def test_skipped_already_silenced(self): """Permissions were not set and `False` was returned for an already silenced channel.""" @@ -298,19 +506,19 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_silenced_channel(self): """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" - self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) - self.assertFalse(self.overwrite.send_messages) - self.assertFalse(self.overwrite.add_reactions) - self.channel.set_permissions.assert_awaited_once_with( - self.cog._verified_role, - overwrite=self.overwrite + self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel)) + self.assertFalse(self.text_overwrite.send_messages) + self.assertFalse(self.text_overwrite.add_reactions) + self.text_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_msg_role, + overwrite=self.text_overwrite ) async def test_preserved_other_overwrites(self): """Channel's other unrelated overwrites were not changed.""" - prev_overwrite_dict = dict(self.overwrite) - await self.cog._set_silence_overwrites(self.channel) - new_overwrite_dict = dict(self.overwrite) + prev_overwrite_dict = dict(self.text_overwrite) + await self.cog._set_silence_overwrites(self.text_channel) + new_overwrite_dict = dict(self.text_overwrite) # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. del prev_overwrite_dict['send_messages'] @@ -341,8 +549,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_cached_previous_overwrites(self): """Channel's previous overwrites were cached.""" overwrite_json = '{"send_messages": true, "add_reactions": false}' - await self.cog._set_silence_overwrites(self.channel) - self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + await self.cog._set_silence_overwrites(self.text_channel) + self.cog.previous_overwrites.set.assert_called_once_with(self.text_channel.id, overwrite_json) @autospec(silence, "datetime") async def test_cached_unsilence_time(self, datetime_mock): @@ -352,7 +560,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): timestamp = now_timestamp + duration * 60 datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, duration) self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) @@ -360,26 +568,37 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_cached_indefinite_time(self): """A value of -1 was cached for a permanent silence.""" - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, None) self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) async def test_scheduled_task(self): """An unsilence task was scheduled.""" - ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + ctx = MockContext(channel=self.text_channel, invoke=mock.MagicMock()) await self.cog.silence.callback(self.cog, ctx, 5) args = (300, ctx.channel.id, ctx.invoke.return_value) self.cog.scheduler.schedule_later.assert_called_once_with(*args) - ctx.invoke.assert_called_once_with(self.cog.unsilence) + ctx.invoke.assert_called_once_with(self.cog.unsilence, channel=ctx.channel) async def test_permanent_not_scheduled(self): """A task was not scheduled for a permanent silence.""" - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, None) self.cog.scheduler.schedule_later.assert_not_called() + async def test_correct_permission_updates(self): + """Tests if _set_silence_overwrites can correctly get and update permissions.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel)) + self.assertFalse(self.text_overwrite.send_messages or self.text_overwrite.add_reactions) + + self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel)) + self.assertFalse(self.voice_overwrite.speak) + + self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel, True)) + self.assertFalse(self.voice_overwrite.speak or self.voice_overwrite.connect) + @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) class UnsilenceTests(unittest.IsolatedAsyncioTestCase): @@ -422,6 +641,42 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): await self.cog.unsilence.callback(self.cog, ctx) ctx.channel.send.assert_called_once_with(message) + async def test_sent_to_correct_channel(self): + """Test function sends messages to the correct channels.""" + unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) + text_channel = MockTextChannel() + ctx = MockContext() + + test_cases = ( + (None, silence.MSG_UNSILENCE_SUCCESS.format(duration=10)), + (text_channel, silence.MSG_UNSILENCE_SUCCESS.format(duration=10)), + (ctx.channel, silence.MSG_UNSILENCE_SUCCESS.format(duration=10)), + ) + + for target, message in test_cases: + with self.subTest(target_channel=target, message=message): + with mock.patch.object(self.cog, "_unsilence", return_value=True): + # Assign Return + if ctx.channel == target or target is None: + ctx.channel.overwrites_for.return_value = unsilenced_overwrite + else: + target.overwrites_for.return_value = unsilenced_overwrite + + await self.cog.unsilence.callback(self.cog, ctx, channel=target) + + # Check Messages + if ctx.channel == target or target is None: + ctx.channel.send.assert_called_once_with(message) + else: + ctx.channel.send.assert_called_once_with( + message.replace("current channel", text_channel.mention) + ) + target.send.assert_called_once_with(message) + + ctx.channel.send.reset_mock() + if target is not None: + target.send.reset_mock() + async def test_skipped_already_unsilenced(self): """Permissions were not set and `False` was returned for an already unsilenced channel.""" self.cog.scheduler.__contains__.return_value = False @@ -435,7 +690,7 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): """Channel's `send_message` and `add_reactions` overwrites were restored.""" await self.cog._unsilence(self.channel) self.channel.set_permissions.assert_awaited_once_with( - self.cog._verified_role, + self.cog._verified_msg_role, overwrite=self.overwrite, ) @@ -449,7 +704,7 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): await self.cog._unsilence(self.channel) self.channel.set_permissions.assert_awaited_once_with( - self.cog._verified_role, + self.cog._verified_msg_role, overwrite=self.overwrite, ) @@ -462,6 +717,10 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): await self.cog._unsilence(self.channel) self.cog._mod_alerts_channel.send.assert_awaited_once() + self.cog._mod_alerts_channel.send.reset_mock() + + await self.cog._unsilence(MockVoiceChannel()) + self.cog._mod_alerts_channel.send.assert_awaited_once() async def test_removed_notifier(self): """Channel was removed from `notifier`.""" @@ -500,3 +759,86 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): del new_overwrite_dict['add_reactions'] self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + @mock.patch.object(silence.Silence, "_unsilence", return_value=False) + @mock.patch.object(silence.Silence, "send_message") + async def test_unsilence_helper_fail(self, send_message, _): + """Tests unsilence_wrapper when `_unsilence` fails.""" + ctx = MockContext() + + text_channel = MockTextChannel() + text_role = self.cog.bot.get_guild(Guild.id).get_role(Roles.verified) + + voice_channel = MockVoiceChannel() + voice_role = self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified) + + test_cases = ( + (ctx, text_channel, text_role, True, silence.MSG_UNSILENCE_FAIL), + (ctx, text_channel, text_role, False, silence.MSG_UNSILENCE_MANUAL), + (ctx, voice_channel, voice_role, True, silence.MSG_UNSILENCE_FAIL), + (ctx, voice_channel, voice_role, False, silence.MSG_UNSILENCE_MANUAL), + ) + + class PermClass: + """Class to Mock return permissions""" + def __init__(self, value: bool): + self.val = value + + def __getattr__(self, item): + return self.val + + for context, channel, role, permission, message in test_cases: + with mock.patch.object(channel, "overwrites_for", return_value=PermClass(permission)) as overwrites: + with self.subTest(channel=channel, message=message): + await self.cog._unsilence_wrapper(channel, context) + + overwrites.assert_called_once_with(role) + send_message.assert_called_once_with(message, ctx.channel, channel) + + send_message.reset_mock() + + @mock.patch.object(silence.Silence, "_force_voice_sync") + @mock.patch.object(silence.Silence, "send_message") + async def test_correct_overwrites(self, send_message, _): + """Tests the overwrites returned by the _unsilence_wrapper are correct for voice and text channels.""" + ctx = MockContext() + + text_channel = MockTextChannel() + text_role = self.cog.bot.get_guild(Guild.id).get_role(Roles.verified) + + voice_channel = MockVoiceChannel() + voice_role = self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified) + + async def reset(): + await text_channel.set_permissions(text_role, PermissionOverwrite(send_messages=False, add_reactions=False)) + await voice_channel.set_permissions(voice_role, PermissionOverwrite(speak=False, connect=False)) + + text_channel.reset_mock() + voice_channel.reset_mock() + send_message.reset_mock() + await reset() + + default_text_overwrites = text_channel.overwrites_for(text_role) + default_voice_overwrites = voice_channel.overwrites_for(voice_role) + + test_cases = ( + (ctx, text_channel, text_role, default_text_overwrites, silence.MSG_UNSILENCE_SUCCESS), + (ctx, voice_channel, voice_role, default_voice_overwrites, silence.MSG_UNSILENCE_SUCCESS), + (ctx, ctx.channel, text_role, ctx.channel.overwrites_for(text_role), silence.MSG_UNSILENCE_SUCCESS), + (None, text_channel, text_role, default_text_overwrites, silence.MSG_UNSILENCE_SUCCESS), + ) + + for context, channel, role, overwrites, message in test_cases: + with self.subTest(ctx=context, channel=channel): + await self.cog._unsilence_wrapper(channel, context) + + if context is None: + send_message.assert_called_once_with(message, channel, channel, True) + else: + send_message.assert_called_once_with(message, context.channel, channel, True) + + channel.set_permissions.assert_called_once_with(role, overwrite=overwrites) + if channel != ctx.channel: + ctx.channel.send.assert_not_called() + + await reset() diff --git a/tests/helpers.py b/tests/helpers.py index 870f66197..5628ca31f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module - for logger in logging.Logger.manager.loggerDict.values(): # Set all loggers to CRITICAL by default to prevent screen clutter during testing @@ -320,7 +319,10 @@ channel_data = { } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + +channel_data["type"] = "VoiceChannel" +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): @@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ - spec_set = channel_instance + spec_set = text_channel_instance + + def __init__(self, **kwargs) -> None: + default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"#{self.name}" + + +class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): + """ + A MagicMock subclass to mock VoiceChannel objects. + + Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For + more information, see the `MockGuild` docstring. + """ + spec_set = voice_channel_instance def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} |