diff options
-rw-r--r-- | bot/constants.py | 10 | ||||
-rw-r--r-- | bot/exts/moderation/silence.py | 287 | ||||
-rw-r--r-- | config-default.yml | 14 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 505 | ||||
-rw-r--r-- | tests/helpers.py | 25 |
5 files changed, 696 insertions, 145 deletions
diff --git a/bot/constants.py b/bot/constants.py index 7b2a38079..0c602f19b 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -449,15 +449,17 @@ class Channels(metaclass=YAMLGetter): staff_announcements: int admins_voice: int + code_help_voice_0: int code_help_voice_1: int - code_help_voice_2: int - general_voice: int + general_voice_0: int + general_voice_1: int staff_voice: int + code_help_chat_0: int code_help_chat_1: int - code_help_chat_2: int staff_voice_chat: int - voice_chat: int + voice_chat_0: int + voice_chat_1: int big_brother_logs: int talent_pool: int diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 2a7ca932e..616dfbefb 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -2,35 +2,44 @@ import json import logging from contextlib import suppress from datetime import datetime, timedelta, timezone -from operator import attrgetter -from typing import Optional +from typing import Optional, OrderedDict, Union from async_rediscache import RedisCache -from discord import TextChannel +from discord import Guild, PermissionOverwrite, TextChannel, VoiceChannel from discord.ext import commands, tasks from discord.ext.commands import Context +from bot import constants from bot.bot import Bot -from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles from bot.converters import HushDurationConverter -from bot.utils.lock import LockedResourceError, lock_arg +from bot.utils.lock import LockedResourceError, lock, lock_arg from bot.utils.scheduling import Scheduler log = logging.getLogger(__name__) LOCK_NAMESPACE = "silence" -MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced." -MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely." -MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)." +MSG_SILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel is already silenced." +MSG_SILENCE_PERMANENT = f"{constants.Emojis.check_mark} silenced current channel indefinitely." +MSG_SILENCE_SUCCESS = f"{constants.Emojis.check_mark} silenced current channel for {{duration}} minute(s)." -MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced." +MSG_UNSILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel was not silenced." MSG_UNSILENCE_MANUAL = ( - f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " + f"{constants.Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " f"set manually or the cache was prematurely cleared. " f"Please edit the overwrites manually to unsilence." ) -MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel." +MSG_UNSILENCE_SUCCESS = f"{constants.Emojis.check_mark} unsilenced current channel." + +TextOrVoiceChannel = Union[TextChannel, VoiceChannel] + +VOICE_CHANNELS = { + constants.Channels.code_help_voice_0: constants.Channels.code_help_chat_0, + constants.Channels.code_help_voice_1: constants.Channels.code_help_chat_1, + constants.Channels.general_voice_0: constants.Channels.voice_chat_0, + constants.Channels.general_voice_1: constants.Channels.voice_chat_1, + constants.Channels.staff_voice: constants.Channels.staff_voice_chat, +} class SilenceNotifier(tasks.Loop): @@ -41,7 +50,7 @@ class SilenceNotifier(tasks.Loop): self._silenced_channels = {} self._alert_channel = alert_channel - def add_channel(self, channel: TextChannel) -> None: + def add_channel(self, channel: TextOrVoiceChannel) -> None: """Add channel to `_silenced_channels` and start loop if not launched.""" if not self._silenced_channels: self.start() @@ -68,7 +77,19 @@ class SilenceNotifier(tasks.Loop): f"{channel.mention} for {(self._current_loop-start)//60} min" for channel, start in self._silenced_channels.items() ) - await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}") + await self._alert_channel.send( + f"<@&{constants.Roles.moderators}> currently silenced channels: {channels_text}" + ) + + +async def _select_lock_channel(args: OrderedDict[str, any]) -> TextOrVoiceChannel: + """Passes the channel to be silenced to the resource lock.""" + channel = args["channel"] + if channel is not None: + return channel + + else: + return args["ctx"].channel class Silence(commands.Cog): @@ -92,88 +113,230 @@ class Silence(commands.Cog): """Set instance attributes once the guild is available and reschedule unsilences.""" await self.bot.wait_until_guild_available() - guild = self.bot.get_guild(Guild.id) + guild = self.bot.get_guild(constants.Guild.id) + self._everyone_role = guild.default_role - self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) - self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log)) + self._verified_voice_role = guild.get_role(constants.Roles.voice_verified) + + self._mod_alerts_channel = self.bot.get_channel(constants.Channels.mod_alerts) + + self.notifier = SilenceNotifier(self.bot.get_channel(constants.Channels.mod_log)) await self._reschedule() + async def send_message( + self, + message: str, + source_channel: TextChannel, + target_channel: TextOrVoiceChannel, + *, alert_target: bool = False + ) -> None: + """Helper function to send message confirmation to `source_channel`, and notification to `target_channel`.""" + # Reply to invocation channel + source_reply = message + if source_channel != target_channel: + source_reply = source_reply.replace("current channel", target_channel.mention) + await source_channel.send(source_reply) + + # Reply to target channel + if alert_target: + if isinstance(target_channel, VoiceChannel): + voice_chat = self.bot.get_channel(VOICE_CHANNELS.get(target_channel.id)) + if voice_chat and source_channel != voice_chat: + await voice_chat.send(message.replace("current channel", target_channel.mention)) + + elif source_channel != target_channel: + await target_channel.send(message) + @commands.command(aliases=("hush",)) - @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True) - async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None: + @lock(LOCK_NAMESPACE, _select_lock_channel, raise_error=True) + async def silence( + self, + ctx: Context, + duration: HushDurationConverter = 10, + channel: TextOrVoiceChannel = None, + *, kick: bool = False + ) -> None: """ Silence the current channel for `duration` minutes or `forever`. Duration is capped at 15 minutes, passing forever makes the silence indefinite. Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start. + + Passing a voice channel will attempt to move members out of the channel and back to force sync permissions. + If `kick` is True, members will not be added back to the voice channel, and members will be unable to rejoin. """ await self._init_task - - channel_info = f"#{ctx.channel} ({ctx.channel.id})" + if channel is None: + channel = ctx.channel + channel_info = f"#{channel} ({channel.id})" log.debug(f"{ctx.author} is silencing channel {channel_info}.") - if not await self._set_silence_overwrites(ctx.channel): + if not await self._set_silence_overwrites(channel, kick=kick): log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.") - await ctx.send(MSG_SILENCE_FAIL) + await self.send_message(MSG_SILENCE_FAIL, ctx.channel, channel, alert_target=False) return - await self._schedule_unsilence(ctx, duration) + if isinstance(channel, VoiceChannel): + if kick: + await self._kick_voice_members(channel) + else: + await self._force_voice_sync(channel) + + await self._schedule_unsilence(ctx, channel, duration) if duration is None: - self.notifier.add_channel(ctx.channel) + self.notifier.add_channel(channel) log.info(f"Silenced {channel_info} indefinitely.") - await ctx.send(MSG_SILENCE_PERMANENT) + await self.send_message(MSG_SILENCE_PERMANENT, ctx.channel, channel, alert_target=True) + else: log.info(f"Silenced {channel_info} for {duration} minute(s).") - await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration)) + formatted_message = MSG_SILENCE_SUCCESS.format(duration=duration) + await self.send_message(formatted_message, ctx.channel, channel, alert_target=True) @commands.command(aliases=("unhush",)) - async def unsilence(self, ctx: Context) -> None: + async def unsilence(self, ctx: Context, *, channel: TextOrVoiceChannel = None) -> None: """ - Unsilence the current channel. + Unsilence the given channel if given, else the current one. If the channel was silenced indefinitely, notifications for the channel will stop. """ await self._init_task - log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") - await self._unsilence_wrapper(ctx.channel) + if channel is None: + channel = ctx.channel + log.debug(f"Unsilencing channel #{channel} from {ctx.author}'s command.") + await self._unsilence_wrapper(channel, ctx) @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True) - async def _unsilence_wrapper(self, channel: TextChannel) -> None: - """Unsilence `channel` and send a success/failure message.""" + async def _unsilence_wrapper(self, channel: TextOrVoiceChannel, ctx: Optional[Context] = None) -> None: + """ + Unsilence `channel` and send a success/failure message to ctx.channel. + + If ctx is None or not passed, `channel` is used in its place. + If `channel` and ctx.channel are the same, only one message is sent. + """ + msg_channel = channel + if ctx is not None: + msg_channel = ctx.channel + if not await self._unsilence(channel): - overwrite = channel.overwrites_for(self._everyone_role) - if overwrite.send_messages is False or overwrite.add_reactions is False: - await channel.send(MSG_UNSILENCE_MANUAL) + if isinstance(channel, VoiceChannel): + overwrite = channel.overwrites_for(self._verified_voice_role) + manual = overwrite.speak is False else: - await channel.send(MSG_UNSILENCE_FAIL) + overwrite = channel.overwrites_for(self._everyone_role) + manual = overwrite.send_messages is False or overwrite.add_reactions is False + + # Send fail message to muted channel or voice chat channel, and invocation channel + if manual: + await self.send_message(MSG_UNSILENCE_MANUAL, msg_channel, channel, alert_target=False) + else: + await self.send_message(MSG_UNSILENCE_FAIL, msg_channel, channel, alert_target=False) + else: - await channel.send(MSG_UNSILENCE_SUCCESS) + await self.send_message(MSG_UNSILENCE_SUCCESS, msg_channel, channel, alert_target=True) - async def _set_silence_overwrites(self, channel: TextChannel) -> bool: + async def _set_silence_overwrites(self, channel: TextOrVoiceChannel, *, kick: bool = False) -> bool: """Set silence permission overwrites for `channel` and return True if successful.""" - overwrite = channel.overwrites_for(self._everyone_role) - prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) + # Get the original channel overwrites + if isinstance(channel, TextChannel): + role = self._everyone_role + overwrite = channel.overwrites_for(role) + prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) + else: + role = self._verified_voice_role + overwrite = channel.overwrites_for(role) + prev_overwrites = dict(speak=overwrite.speak) + if kick: + prev_overwrites.update(connect=overwrite.connect) + + # Stop if channel was already silenced if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()): return False - overwrite.update(send_messages=False, add_reactions=False) - await channel.set_permissions(self._everyone_role, overwrite=overwrite) + # Set new permissions, store + overwrite.update(**dict.fromkeys(prev_overwrites, False)) + await channel.set_permissions(role, overwrite=overwrite) await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites)) return True - async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None: + @staticmethod + async def _get_afk_channel(guild: Guild) -> VoiceChannel: + """Get a guild's AFK channel, or create one if it does not exist.""" + afk_channel = guild.afk_channel + + if afk_channel is None: + overwrites = { + guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) + } + afk_channel = await guild.create_voice_channel("mute-temp", overwrites=overwrites) + log.info(f"Failed to get afk-channel, created temporary channel #{afk_channel} ({afk_channel.id})") + + return afk_channel + + @staticmethod + async def _kick_voice_members(channel: VoiceChannel) -> None: + """Remove all non-staff members from a voice channel.""" + log.debug(f"Removing all non staff members from #{channel.name} ({channel.id}).") + + for member in channel.members: + # Skip staff + if any(role.id in constants.MODERATION_ROLES for role in member.roles): + continue + + try: + await member.move_to(None, reason="Kicking member from voice channel.") + log.debug(f"Kicked {member.name} from voice channel.") + except Exception as e: + log.debug(f"Failed to move {member.name}. Reason: {e}") + continue + + log.debug("Removed all members.") + + async def _force_voice_sync(self, channel: VoiceChannel) -> None: + """ + Move all non-staff members from `channel` to a temporary channel and back to force toggle role mute. + + Permission modification has to happen before this function. + """ + # Obtain temporary channel + delete_channel = channel.guild.afk_channel is None + afk_channel = await self._get_afk_channel(channel.guild) + + try: + # Move all members to temporary channel and back + for member in channel.members: + # Skip staff + if any(role.id in constants.MODERATION_ROLES for role in member.roles): + continue + + try: + await member.move_to(afk_channel, reason="Muting VC member.") + log.debug(f"Moved {member.name} to afk channel.") + + await member.move_to(channel, reason="Muting VC member.") + log.debug(f"Moved {member.name} to original voice channel.") + except Exception as e: + log.debug(f"Failed to move {member.name}. Reason: {e}") + continue + + finally: + # Delete VC channel if it was created. + if delete_channel: + await afk_channel.delete(reason="Deleting temporary mute channel.") + + async def _schedule_unsilence(self, ctx: Context, channel: TextOrVoiceChannel, duration: Optional[int]) -> None: """Schedule `ctx.channel` to be unsilenced if `duration` is not None.""" if duration is None: - await self.unsilence_timestamps.set(ctx.channel.id, -1) + await self.unsilence_timestamps.set(channel.id, -1) else: - self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) + self.scheduler.schedule_later(duration * 60, channel.id, ctx.invoke(self.unsilence, channel=channel)) unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration) - await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp()) + await self.unsilence_timestamps.set(channel.id, unsilence_time.timestamp()) - async def _unsilence(self, channel: TextChannel) -> bool: + async def _unsilence(self, channel: TextOrVoiceChannel) -> bool: """ Unsilence `channel`. @@ -183,19 +346,34 @@ class Silence(commands.Cog): Return `True` if channel permissions were changed, `False` otherwise. """ + # Get stored overwrites, and return if channel is unsilenced prev_overwrites = await self.previous_overwrites.get(channel.id) if channel.id not in self.scheduler and prev_overwrites is None: log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") return False - overwrite = channel.overwrites_for(self._everyone_role) + # Select the role based on channel type, and get current overwrites + if isinstance(channel, TextChannel): + role = self._everyone_role + overwrite = channel.overwrites_for(role) + permissions = "`Send Messages` and `Add Reactions`" + else: + role = self._verified_voice_role + overwrite = channel.overwrites_for(role) + permissions = "`Speak` and `Connect`" + + # Check if old overwrites were not stored if prev_overwrites is None: log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.") - overwrite.update(send_messages=None, add_reactions=None) + overwrite.update(send_messages=None, add_reactions=None, speak=None, connect=None) else: overwrite.update(**json.loads(prev_overwrites)) - await channel.set_permissions(self._everyone_role, overwrite=overwrite) + # Update Permissions + await channel.set_permissions(role, overwrite=overwrite) + if isinstance(channel, VoiceChannel): + await self._force_voice_sync(channel) + log.info(f"Unsilenced channel #{channel} ({channel.id}).") self.scheduler.cancel(channel.id) @@ -203,11 +381,12 @@ class Silence(commands.Cog): await self.previous_overwrites.delete(channel.id) await self.unsilence_timestamps.delete(channel.id) + # Alert Admin team if old overwrites were not available if prev_overwrites is None: await self._mod_alerts_channel.send( - f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing " - f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` " - f"overwrites for {self._everyone_role.mention} are at their desired values." + f"<@&{constants.Roles.admins}> Restored overwrites with default values after unsilencing " + f"{channel.mention}. Please check that the {permissions} " + f"overwrites for {role.mention} are at their desired values." ) return True @@ -247,7 +426,7 @@ class Silence(commands.Cog): # This cannot be static (must have a __func__ attribute). async def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" - return await commands.has_any_role(*MODERATION_ROLES).predicate(ctx) + return await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx) def setup(bot: Bot) -> None: diff --git a/config-default.yml b/config-default.yml index 46475f845..b9f6b40ac 100644 --- a/config-default.yml +++ b/config-default.yml @@ -210,16 +210,18 @@ guild: # Voice Channels admins_voice: &ADMINS_VOICE 500734494840717332 - code_help_voice_1: 751592231726481530 - code_help_voice_2: 764232549840846858 - general_voice: 751591688538947646 + code_help_voice_0: 751592231726481530 + code_help_voice_1: 764232549840846858 + general_voice_0: 751591688538947646 + general_voice_1: 799641437645701151 staff_voice: &STAFF_VOICE 412375055910043655 # Voice Chat - code_help_chat_1: 755154969761677312 - code_help_chat_2: 766330079135268884 + code_help_chat_0: 755154969761677312 + code_help_chat_1: 766330079135268884 staff_voice_chat: 541638762007101470 - voice_chat: 412357430186344448 + voice_chat_0: 412357430186344448 + voice_chat_1: 799647045886541885 # Watch big_brother_logs: &BB_LOGS 468507907357409333 diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index fa5fc9e81..729b28412 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,15 +1,25 @@ import asyncio import unittest from datetime import datetime, timezone +from typing import List, Tuple from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from async_rediscache import RedisSession from discord import PermissionOverwrite -from bot.constants import Channels, Guild, Roles +from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence -from tests.helpers import MockBot, MockContext, MockTextChannel, autospec +from tests.helpers import ( + MockBot, + MockContext, + MockGuild, + MockMember, + MockRole, + MockTextChannel, + MockVoiceChannel, + autospec +) redis_session = None redis_loop = asyncio.get_event_loop() @@ -149,7 +159,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(self.cog._init_task.cancelled()) @autospec("discord.ext.commands", "has_any_role") - @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) + @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" ctx = MockContext() @@ -159,6 +169,95 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): role_check.assert_called_once_with(*(1, 2, 3)) role_check.return_value.predicate.assert_awaited_once_with(ctx) + async def test_force_voice_sync(self): + """Tests the _force_voice_sync helper function.""" + await self.cog._async_init() + + # Create a regular member, and one member for each of the moderation roles + moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] + members = [MockMember(), *moderation_members] + + channel = MockVoiceChannel(members=members) + + await self.cog._force_voice_sync(channel) + for member in members: + if member in moderation_members: + member.move_to.assert_not_called() + else: + self.assertEqual(member.move_to.call_count, 2) + calls = member.move_to.call_args_list + + # Tests that the member was moved to the afk channel, and back. + self.assertEqual((channel.guild.afk_channel,), calls[0].args) + self.assertEqual((channel,), calls[1].args) + + async def test_force_voice_sync_no_channel(self): + """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" + await self.cog._async_init() + + channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) + new_channel = MockVoiceChannel(delete=AsyncMock()) + channel.guild.create_voice_channel.return_value = new_channel + + await self.cog._force_voice_sync(channel) + + # Check channel creation + overwrites = { + channel.guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) + } + channel.guild.create_voice_channel.assert_awaited_once_with("mute-temp", overwrites=overwrites) + + # Check bot deleted channel + new_channel.delete.assert_awaited_once() + + async def test_voice_kick(self): + """Test to ensure kick function can remove all members from a voice channel.""" + await self.cog._async_init() + + # Create a regular member, and one member for each of the moderation roles + moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] + members = [MockMember(), *moderation_members] + + channel = MockVoiceChannel(members=members) + await self.cog._kick_voice_members(channel) + + for member in members: + if member in moderation_members: + member.move_to.assert_not_called() + else: + self.assertEqual((None,), member.move_to.call_args_list[0].args) + + @staticmethod + def create_erroneous_members() -> Tuple[List[MockMember], List[MockMember]]: + """ + Helper method to generate a list of members that error out on move_to call. + + Returns the list of erroneous members, + as well as a list of regular and erroneous members combined, in that order. + """ + erroneous_member = MockMember(move_to=AsyncMock(side_effect=Exception())) + members = [MockMember(), erroneous_member] + + return erroneous_member, members + + async def test_kick_move_to_error(self): + """Test to ensure move_to gets called on all members during kick, even if some fail.""" + await self.cog._async_init() + _, members = self.create_erroneous_members() + + await self.cog._kick_voice_members(MockVoiceChannel(members=members)) + for member in members: + member.move_to.assert_awaited_once() + + async def test_sync_move_to_error(self): + """Test to ensure move_to gets called on all members during sync, even if some fail.""" + await self.cog._async_init() + failing_member, members = self.create_erroneous_members() + + await self.cog._force_voice_sync(MockVoiceChannel(members=members)) + for member in members: + self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) + @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) class RescheduleTests(unittest.IsolatedAsyncioTestCase): @@ -235,6 +334,16 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase): self.cog.notifier.add_channel.assert_not_called() +def voice_sync_helper(function): + """Helper wrapper to test the sync and kick functions for voice channels.""" + @autospec(silence.Silence, "_force_voice_sync", "_kick_voice_members", "_set_silence_overwrites") + async def inner(self, sync, kick, overwrites): + overwrites.return_value = True + await function(self, MockContext(), sync, kick) + + return inner + + @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) class SilenceTests(unittest.IsolatedAsyncioTestCase): """Tests for the silence command and its related helper methods.""" @@ -242,7 +351,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "_reschedule", pass_mocks=False) @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot() + self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) self.cog._init_task = asyncio.Future() self.cog._init_task.set_result(None) @@ -252,56 +361,126 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): asyncio.run(self.cog._async_init()) # Populate instance attributes. - self.channel = MockTextChannel() - self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) - self.channel.overwrites_for.return_value = self.overwrite + self.text_channel = MockTextChannel() + self.text_overwrite = PermissionOverwrite(send_messages=True, add_reactions=False) + self.text_channel.overwrites_for.return_value = self.text_overwrite + + self.voice_channel = MockVoiceChannel() + self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) + self.voice_channel.overwrites_for.return_value = self.voice_overwrite async def test_sent_correct_message(self): - """Appropriate failure/success message was sent by the command.""" + """Appropriate failure/success message was sent by the command to the correct channel.""" + # The following test tuples are made up of: + # duration, expected message, and the success of the _set_silence_overwrites function test_cases = ( (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), (None, silence.MSG_SILENCE_PERMANENT, True,), (5, silence.MSG_SILENCE_FAIL, False,), ) + for duration, message, was_silenced in test_cases: - ctx = MockContext() with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): - with self.subTest(was_silenced=was_silenced, message=message, duration=duration): - await self.cog.silence.callback(self.cog, ctx, duration) - ctx.send.assert_called_once_with(message) + for target in [MockTextChannel(), MockVoiceChannel(), None]: + with self.subTest(was_silenced=was_silenced, target=target, message=message): + with mock.patch.object(self.cog, "send_message") as send_message: + ctx = MockContext() + await self.cog.silence.callback(self.cog, ctx, duration, target) + send_message.assert_called_once_with( + message, + ctx.channel, + target or ctx.channel, + alert_target=was_silenced + ) + + @voice_sync_helper + async def test_sync_called(self, ctx, sync, kick): + """Tests if silence command calls sync on a voice channel.""" + channel = MockVoiceChannel() + await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False) + + sync.assert_awaited_once_with(self.cog, channel) + kick.assert_not_called() + + @voice_sync_helper + async def test_kick_called(self, ctx, sync, kick): + """Tests if silence command calls kick on a voice channel.""" + channel = MockVoiceChannel() + await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True) + + kick.assert_awaited_once_with(channel) + sync.assert_not_called() + + @voice_sync_helper + async def test_sync_not_called(self, ctx, sync, kick): + """Tests that silence command does not call sync on a text channel.""" + channel = MockTextChannel() + await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False) + + sync.assert_not_called() + kick.assert_not_called() + + @voice_sync_helper + async def test_kick_not_called(self, ctx, sync, kick): + """Tests that silence command does not call kick on a text channel.""" + channel = MockTextChannel() + await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True) + + sync.assert_not_called() + kick.assert_not_called() async def test_skipped_already_silenced(self): """Permissions were not set and `False` was returned for an already silenced channel.""" subtests = ( - (False, PermissionOverwrite(send_messages=False, add_reactions=False)), - (True, PermissionOverwrite(send_messages=True, add_reactions=True)), - (True, PermissionOverwrite(send_messages=False, add_reactions=False)), + (False, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)), + (True, MockTextChannel(), PermissionOverwrite(send_messages=True, add_reactions=True)), + (True, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)), + (False, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)), + (True, MockVoiceChannel(), PermissionOverwrite(connect=True, speak=True)), + (True, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)), ) - for contains, overwrite in subtests: - with self.subTest(contains=contains, overwrite=overwrite): + for contains, channel, overwrite in subtests: + with self.subTest(contains=contains, is_text=isinstance(channel, MockTextChannel), overwrite=overwrite): self.cog.scheduler.__contains__.return_value = contains - channel = MockTextChannel() channel.overwrites_for.return_value = overwrite self.assertFalse(await self.cog._set_silence_overwrites(channel)) channel.set_permissions.assert_not_called() - async def test_silenced_channel(self): + async def test_silenced_text_channel(self): """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" - self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) - self.assertFalse(self.overwrite.send_messages) - self.assertFalse(self.overwrite.add_reactions) - self.channel.set_permissions.assert_awaited_once_with( + self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel)) + self.assertFalse(self.text_overwrite.send_messages) + self.assertFalse(self.text_overwrite.add_reactions) + self.text_channel.set_permissions.assert_awaited_once_with( self.cog._everyone_role, - overwrite=self.overwrite + overwrite=self.text_overwrite + ) + + async def test_silenced_voice_channel_speak(self): + """Channel had `speak` permissions revoked for verified role.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel)) + self.assertFalse(self.voice_overwrite.speak) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite ) - async def test_preserved_other_overwrites(self): - """Channel's other unrelated overwrites were not changed.""" - prev_overwrite_dict = dict(self.overwrite) - await self.cog._set_silence_overwrites(self.channel) - new_overwrite_dict = dict(self.overwrite) + async def test_silenced_voice_channel_full(self): + """Channel had `speak` and `connect` permissions revoked for verified role.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel, kick=True)) + self.assertFalse(self.voice_overwrite.speak or self.voice_overwrite.connect) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite + ) + + async def test_preserved_other_overwrites_text(self): + """Channel's other unrelated overwrites were not changed for a text channel mute.""" + prev_overwrite_dict = dict(self.text_overwrite) + await self.cog._set_silence_overwrites(self.text_channel) + new_overwrite_dict = dict(self.text_overwrite) # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. del prev_overwrite_dict['send_messages'] @@ -311,6 +490,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + async def test_preserved_other_overwrites_voice(self): + """Channel's other unrelated overwrites were not changed for a voice channel mute.""" + prev_overwrite_dict = dict(self.voice_overwrite) + await self.cog._set_silence_overwrites(self.voice_channel) + new_overwrite_dict = dict(self.voice_overwrite) + + # Remove 'connect' & 'speak' keys because they were changed by the method. + del prev_overwrite_dict['connect'] + del prev_overwrite_dict['speak'] + del new_overwrite_dict['connect'] + del new_overwrite_dict['speak'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + async def test_temp_not_added_to_notifier(self): """Channel was not added to notifier if a duration was set for the silence.""" with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): @@ -332,8 +525,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_cached_previous_overwrites(self): """Channel's previous overwrites were cached.""" overwrite_json = '{"send_messages": true, "add_reactions": false}' - await self.cog._set_silence_overwrites(self.channel) - self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + await self.cog._set_silence_overwrites(self.text_channel) + self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json) @autospec(silence, "datetime") async def test_cached_unsilence_time(self, datetime_mock): @@ -343,7 +536,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): timestamp = now_timestamp + duration * 60 datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, duration) self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) @@ -351,23 +544,23 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_cached_indefinite_time(self): """A value of -1 was cached for a permanent silence.""" - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, None) self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) async def test_scheduled_task(self): """An unsilence task was scheduled.""" - ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + ctx = MockContext(channel=self.text_channel, invoke=mock.MagicMock()) await self.cog.silence.callback(self.cog, ctx, 5) args = (300, ctx.channel.id, ctx.invoke.return_value) self.cog.scheduler.schedule_later.assert_called_once_with(*args) - ctx.invoke.assert_called_once_with(self.cog.unsilence) + ctx.invoke.assert_called_once_with(self.cog.unsilence, channel=ctx.channel) async def test_permanent_not_scheduled(self): """A task was not scheduled for a permanent silence.""" - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, None) self.cog.scheduler.schedule_later.assert_not_called() @@ -391,9 +584,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): self.cog.scheduler.__contains__.return_value = True overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' - self.channel = MockTextChannel() - self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) - self.channel.overwrites_for.return_value = self.overwrite + self.text_channel = MockTextChannel() + self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) + self.text_channel.overwrites_for.return_value = self.text_overwrite + + self.voice_channel = MockVoiceChannel() + self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) + self.voice_channel.overwrites_for.return_value = self.voice_overwrite async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" @@ -401,88 +598,128 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): test_cases = ( (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), - (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), + (False, silence.MSG_UNSILENCE_MANUAL, self.text_overwrite), (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), ) + for was_unsilenced, message, overwrite in test_cases: ctx = MockContext() - with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): + + for target in [None, MockTextChannel()]: + ctx.channel.overwrites_for.return_value = overwrite + if target: + target.overwrites_for.return_value = overwrite + with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): - ctx.channel.overwrites_for.return_value = overwrite - await self.cog.unsilence.callback(self.cog, ctx) - ctx.channel.send.assert_called_once_with(message) + with mock.patch.object(self.cog, "send_message") as send_message: + with self.subTest(was_unsilenced=was_unsilenced, overwrite=overwrite, target=target): + await self.cog.unsilence.callback(self.cog, ctx, channel=target) + + call_args = (message, ctx.channel, target or ctx.channel) + send_message.assert_awaited_once_with(*call_args, alert_target=was_unsilenced) async def test_skipped_already_unsilenced(self): """Permissions were not set and `False` was returned for an already unsilenced channel.""" self.cog.scheduler.__contains__.return_value = False self.cog.previous_overwrites.get.return_value = None - channel = MockTextChannel() - self.assertFalse(await self.cog._unsilence(channel)) - channel.set_permissions.assert_not_called() + for channel in (MockVoiceChannel(), MockTextChannel()): + with self.subTest(channel=channel): + self.assertFalse(await self.cog._unsilence(channel)) + channel.set_permissions.assert_not_called() - async def test_restored_overwrites(self): - """Channel's `send_message` and `add_reactions` overwrites were restored.""" - await self.cog._unsilence(self.channel) - self.channel.set_permissions.assert_awaited_once_with( + async def test_restored_overwrites_text(self): + """Text channel's `send_message` and `add_reactions` overwrites were restored.""" + await self.cog._unsilence(self.text_channel) + self.text_channel.set_permissions.assert_awaited_once_with( self.cog._everyone_role, - overwrite=self.overwrite, + overwrite=self.text_overwrite, + ) + + # Recall that these values are determined by the fixture. + self.assertTrue(self.text_overwrite.send_messages) + self.assertFalse(self.text_overwrite.add_reactions) + + async def test_restored_overwrites_voice(self): + """Voice channel's `connect` and `speak` overwrites were restored.""" + await self.cog._unsilence(self.voice_channel) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite, ) # Recall that these values are determined by the fixture. - self.assertTrue(self.overwrite.send_messages) - self.assertFalse(self.overwrite.add_reactions) + self.assertTrue(self.voice_overwrite.connect) + self.assertTrue(self.voice_overwrite.speak) - async def test_cache_miss_used_default_overwrites(self): - """Both overwrites were set to None due previous values not being found in the cache.""" + async def test_cache_miss_used_default_overwrites_text(self): + """Text overwrites were set to None due previous values not being found in the cache.""" self.cog.previous_overwrites.get.return_value = None - await self.cog._unsilence(self.channel) - self.channel.set_permissions.assert_awaited_once_with( + await self.cog._unsilence(self.text_channel) + self.text_channel.set_permissions.assert_awaited_once_with( self.cog._everyone_role, - overwrite=self.overwrite, + overwrite=self.text_overwrite, ) - self.assertIsNone(self.overwrite.send_messages) - self.assertIsNone(self.overwrite.add_reactions) + self.assertIsNone(self.text_overwrite.send_messages) + self.assertIsNone(self.text_overwrite.add_reactions) - async def test_cache_miss_sent_mod_alert(self): - """A message was sent to the mod alerts channel.""" + async def test_cache_miss_used_default_overwrites_voice(self): + """Voice overwrites were set to None due previous values not being found in the cache.""" self.cog.previous_overwrites.get.return_value = None - await self.cog._unsilence(self.channel) + await self.cog._unsilence(self.voice_channel) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite, + ) + + self.assertIsNone(self.voice_overwrite.connect) + self.assertIsNone(self.voice_overwrite.speak) + + async def test_cache_miss_sent_mod_alert_text(self): + """A message was sent to the mod alerts channel upon muting a text channel.""" + self.cog.previous_overwrites.get.return_value = None + await self.cog._unsilence(self.text_channel) + self.cog._mod_alerts_channel.send.assert_awaited_once() + + async def test_cache_miss_sent_mod_alert_voice(self): + """A message was sent to the mod alerts channel upon muting a voice channel.""" + self.cog.previous_overwrites.get.return_value = None + await self.cog._unsilence(MockVoiceChannel()) self.cog._mod_alerts_channel.send.assert_awaited_once() async def test_removed_notifier(self): """Channel was removed from `notifier`.""" - await self.cog._unsilence(self.channel) - self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + await self.cog._unsilence(self.text_channel) + self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel) async def test_deleted_cached_overwrite(self): """Channel was deleted from the overwrites cache.""" - await self.cog._unsilence(self.channel) - self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + await self.cog._unsilence(self.text_channel) + self.cog.previous_overwrites.delete.assert_awaited_once_with(self.text_channel.id) async def test_deleted_cached_time(self): """Channel was deleted from the timestamp cache.""" - await self.cog._unsilence(self.channel) - self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + await self.cog._unsilence(self.text_channel) + self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.text_channel.id) async def test_cancelled_task(self): """The scheduled unsilence task should be cancelled.""" - await self.cog._unsilence(self.channel) - self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + await self.cog._unsilence(self.text_channel) + self.cog.scheduler.cancel.assert_called_once_with(self.text_channel.id) - async def test_preserved_other_overwrites(self): - """Channel's other unrelated overwrites were not changed, including cache misses.""" + async def test_preserved_other_overwrites_text(self): + """Text channel's other unrelated overwrites were not changed, including cache misses.""" for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): with self.subTest(overwrite_json=overwrite_json): self.cog.previous_overwrites.get.return_value = overwrite_json - prev_overwrite_dict = dict(self.overwrite) - await self.cog._unsilence(self.channel) - new_overwrite_dict = dict(self.overwrite) + prev_overwrite_dict = dict(self.text_overwrite) + await self.cog._unsilence(self.text_channel) + new_overwrite_dict = dict(self.text_overwrite) # Remove these keys because they were modified by the unsilence. del prev_overwrite_dict['send_messages'] @@ -491,3 +728,115 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): del new_overwrite_dict['add_reactions'] self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + async def test_preserved_other_overwrites_voice(self): + """Voice channel's other unrelated overwrites were not changed, including cache misses.""" + for overwrite_json in ('{"connect": true, "speak": true}', None): + with self.subTest(overwrite_json=overwrite_json): + self.cog.previous_overwrites.get.return_value = overwrite_json + + prev_overwrite_dict = dict(self.voice_overwrite) + await self.cog._unsilence(self.voice_channel) + new_overwrite_dict = dict(self.voice_overwrite) + + # Remove these keys because they were modified by the unsilence. + del prev_overwrite_dict['connect'] + del prev_overwrite_dict['speak'] + del new_overwrite_dict['connect'] + del new_overwrite_dict['speak'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + async def test_unsilence_role(self): + """Tests unsilence_wrapper applies permission to the correct role.""" + test_cases = ( + (MockTextChannel(), self.cog.bot.get_guild(Guild.id).default_role), + (MockVoiceChannel(), self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified)) + ) + + for channel, role in test_cases: + with self.subTest(channel=channel, role=role): + await self.cog._unsilence_wrapper(channel, MockContext()) + channel.overwrites_for.assert_called_with(role) + + +class SendMessageTests(unittest.IsolatedAsyncioTestCase): + """Unittests for the send message helper function.""" + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + + self.text_channels = [MockTextChannel() for _ in range(2)] + self.bot.get_channel.return_value = self.text_channels[1] + + self.voice_channel = MockVoiceChannel() + + async def test_send_to_channel(self): + """Tests a basic case for the send function.""" + message = "Test basic message." + await self.cog.send_message(message, *self.text_channels, alert_target=False) + + self.text_channels[0].send.assert_awaited_once_with(message) + self.text_channels[1].send.assert_not_called() + + async def test_send_to_multiple_channels(self): + """Tests sending messages to two channels.""" + message = "Test basic message." + await self.cog.send_message(message, *self.text_channels, alert_target=True) + + self.text_channels[0].send.assert_awaited_once_with(message) + self.text_channels[1].send.assert_awaited_once_with(message) + + async def test_duration_replacement(self): + """Tests that the channel name was set correctly for one target channel.""" + message = "Current. The following should be replaced: current channel." + await self.cog.send_message(message, *self.text_channels, alert_target=False) + + updated_message = message.replace("current channel", self.text_channels[0].mention) + self.text_channels[0].send.assert_awaited_once_with(updated_message) + self.text_channels[1].send.assert_not_called() + + async def test_name_replacement_multiple_channels(self): + """Tests that the channel name was set correctly for two channels.""" + message = "Current. The following should be replaced: current channel." + await self.cog.send_message(message, *self.text_channels, alert_target=True) + + updated_message = message.replace("current channel", self.text_channels[0].mention) + self.text_channels[0].send.assert_awaited_once_with(updated_message) + self.text_channels[1].send.assert_awaited_once_with(message) + + async def test_silence_voice(self): + """Tests that the correct message was sent when a voice channel is muted without alerting.""" + message = "This should show up just here." + await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=False) + self.text_channels[0].send.assert_awaited_once_with(message) + self.text_channels[1].send.assert_not_called() + + async def test_silence_voice_alert(self): + """Tests that the correct message was sent when a voice channel is muted with alerts.""" + with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: + mock_voice_channels.get.return_value = self.text_channels[1].id + + message = "This should show up as current channel." + await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=True) + + updated_message = message.replace("current channel", self.voice_channel.mention) + self.text_channels[0].send.assert_awaited_once_with(updated_message) + self.text_channels[1].send.assert_awaited_once_with(updated_message) + + mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) + + async def test_silence_voice_sibling_channel(self): + """Tests silencing a voice channel from the related text channel.""" + with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: + mock_voice_channels.get.return_value = self.text_channels[1].id + + message = "This should show up as current channel." + await self.cog.send_message(message, self.text_channels[1], self.voice_channel, alert_target=True) + + updated_message = message.replace("current channel", self.voice_channel.mention) + self.text_channels[1].send.assert_awaited_once_with(updated_message) + + mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) + self.bot.get_channel.assert_called_once_with(self.text_channels[1].id) diff --git a/tests/helpers.py b/tests/helpers.py index e3dc5fe5b..86cc635f8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module - for logger in logging.Logger.manager.loggerDict.values(): # Set all loggers to CRITICAL by default to prevent screen clutter during testing @@ -320,7 +319,10 @@ channel_data = { } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + +channel_data["type"] = "VoiceChannel" +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): @@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ - spec_set = channel_instance + spec_set = text_channel_instance + + def __init__(self, **kwargs) -> None: + default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"#{self.name}" + + +class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): + """ + A MagicMock subclass to mock VoiceChannel objects. + + Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For + more information, see the `MockGuild` docstring. + """ + spec_set = voice_channel_instance def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} |