aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/constants.py10
-rw-r--r--bot/exts/moderation/silence.py287
-rw-r--r--config-default.yml14
-rw-r--r--tests/bot/exts/moderation/test_silence.py505
-rw-r--r--tests/helpers.py25
5 files changed, 696 insertions, 145 deletions
diff --git a/bot/constants.py b/bot/constants.py
index 7b2a38079..0c602f19b 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -449,15 +449,17 @@ class Channels(metaclass=YAMLGetter):
staff_announcements: int
admins_voice: int
+ code_help_voice_0: int
code_help_voice_1: int
- code_help_voice_2: int
- general_voice: int
+ general_voice_0: int
+ general_voice_1: int
staff_voice: int
+ code_help_chat_0: int
code_help_chat_1: int
- code_help_chat_2: int
staff_voice_chat: int
- voice_chat: int
+ voice_chat_0: int
+ voice_chat_1: int
big_brother_logs: int
talent_pool: int
diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py
index 2a7ca932e..616dfbefb 100644
--- a/bot/exts/moderation/silence.py
+++ b/bot/exts/moderation/silence.py
@@ -2,35 +2,44 @@ import json
import logging
from contextlib import suppress
from datetime import datetime, timedelta, timezone
-from operator import attrgetter
-from typing import Optional
+from typing import Optional, OrderedDict, Union
from async_rediscache import RedisCache
-from discord import TextChannel
+from discord import Guild, PermissionOverwrite, TextChannel, VoiceChannel
from discord.ext import commands, tasks
from discord.ext.commands import Context
+from bot import constants
from bot.bot import Bot
-from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles
from bot.converters import HushDurationConverter
-from bot.utils.lock import LockedResourceError, lock_arg
+from bot.utils.lock import LockedResourceError, lock, lock_arg
from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
LOCK_NAMESPACE = "silence"
-MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced."
-MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely."
-MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)."
+MSG_SILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel is already silenced."
+MSG_SILENCE_PERMANENT = f"{constants.Emojis.check_mark} silenced current channel indefinitely."
+MSG_SILENCE_SUCCESS = f"{constants.Emojis.check_mark} silenced current channel for {{duration}} minute(s)."
-MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced."
+MSG_UNSILENCE_FAIL = f"{constants.Emojis.cross_mark} current channel was not silenced."
MSG_UNSILENCE_MANUAL = (
- f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were "
+ f"{constants.Emojis.cross_mark} current channel was not unsilenced because the current overwrites were "
f"set manually or the cache was prematurely cleared. "
f"Please edit the overwrites manually to unsilence."
)
-MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel."
+MSG_UNSILENCE_SUCCESS = f"{constants.Emojis.check_mark} unsilenced current channel."
+
+TextOrVoiceChannel = Union[TextChannel, VoiceChannel]
+
+VOICE_CHANNELS = {
+ constants.Channels.code_help_voice_0: constants.Channels.code_help_chat_0,
+ constants.Channels.code_help_voice_1: constants.Channels.code_help_chat_1,
+ constants.Channels.general_voice_0: constants.Channels.voice_chat_0,
+ constants.Channels.general_voice_1: constants.Channels.voice_chat_1,
+ constants.Channels.staff_voice: constants.Channels.staff_voice_chat,
+}
class SilenceNotifier(tasks.Loop):
@@ -41,7 +50,7 @@ class SilenceNotifier(tasks.Loop):
self._silenced_channels = {}
self._alert_channel = alert_channel
- def add_channel(self, channel: TextChannel) -> None:
+ def add_channel(self, channel: TextOrVoiceChannel) -> None:
"""Add channel to `_silenced_channels` and start loop if not launched."""
if not self._silenced_channels:
self.start()
@@ -68,7 +77,19 @@ class SilenceNotifier(tasks.Loop):
f"{channel.mention} for {(self._current_loop-start)//60} min"
for channel, start in self._silenced_channels.items()
)
- await self._alert_channel.send(f"<@&{Roles.moderators}> currently silenced channels: {channels_text}")
+ await self._alert_channel.send(
+ f"<@&{constants.Roles.moderators}> currently silenced channels: {channels_text}"
+ )
+
+
+async def _select_lock_channel(args: OrderedDict[str, any]) -> TextOrVoiceChannel:
+ """Passes the channel to be silenced to the resource lock."""
+ channel = args["channel"]
+ if channel is not None:
+ return channel
+
+ else:
+ return args["ctx"].channel
class Silence(commands.Cog):
@@ -92,88 +113,230 @@ class Silence(commands.Cog):
"""Set instance attributes once the guild is available and reschedule unsilences."""
await self.bot.wait_until_guild_available()
- guild = self.bot.get_guild(Guild.id)
+ guild = self.bot.get_guild(constants.Guild.id)
+
self._everyone_role = guild.default_role
- self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts)
- self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log))
+ self._verified_voice_role = guild.get_role(constants.Roles.voice_verified)
+
+ self._mod_alerts_channel = self.bot.get_channel(constants.Channels.mod_alerts)
+
+ self.notifier = SilenceNotifier(self.bot.get_channel(constants.Channels.mod_log))
await self._reschedule()
+ async def send_message(
+ self,
+ message: str,
+ source_channel: TextChannel,
+ target_channel: TextOrVoiceChannel,
+ *, alert_target: bool = False
+ ) -> None:
+ """Helper function to send message confirmation to `source_channel`, and notification to `target_channel`."""
+ # Reply to invocation channel
+ source_reply = message
+ if source_channel != target_channel:
+ source_reply = source_reply.replace("current channel", target_channel.mention)
+ await source_channel.send(source_reply)
+
+ # Reply to target channel
+ if alert_target:
+ if isinstance(target_channel, VoiceChannel):
+ voice_chat = self.bot.get_channel(VOICE_CHANNELS.get(target_channel.id))
+ if voice_chat and source_channel != voice_chat:
+ await voice_chat.send(message.replace("current channel", target_channel.mention))
+
+ elif source_channel != target_channel:
+ await target_channel.send(message)
+
@commands.command(aliases=("hush",))
- @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True)
- async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None:
+ @lock(LOCK_NAMESPACE, _select_lock_channel, raise_error=True)
+ async def silence(
+ self,
+ ctx: Context,
+ duration: HushDurationConverter = 10,
+ channel: TextOrVoiceChannel = None,
+ *, kick: bool = False
+ ) -> None:
"""
Silence the current channel for `duration` minutes or `forever`.
Duration is capped at 15 minutes, passing forever makes the silence indefinite.
Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start.
+
+ Passing a voice channel will attempt to move members out of the channel and back to force sync permissions.
+ If `kick` is True, members will not be added back to the voice channel, and members will be unable to rejoin.
"""
await self._init_task
-
- channel_info = f"#{ctx.channel} ({ctx.channel.id})"
+ if channel is None:
+ channel = ctx.channel
+ channel_info = f"#{channel} ({channel.id})"
log.debug(f"{ctx.author} is silencing channel {channel_info}.")
- if not await self._set_silence_overwrites(ctx.channel):
+ if not await self._set_silence_overwrites(channel, kick=kick):
log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.")
- await ctx.send(MSG_SILENCE_FAIL)
+ await self.send_message(MSG_SILENCE_FAIL, ctx.channel, channel, alert_target=False)
return
- await self._schedule_unsilence(ctx, duration)
+ if isinstance(channel, VoiceChannel):
+ if kick:
+ await self._kick_voice_members(channel)
+ else:
+ await self._force_voice_sync(channel)
+
+ await self._schedule_unsilence(ctx, channel, duration)
if duration is None:
- self.notifier.add_channel(ctx.channel)
+ self.notifier.add_channel(channel)
log.info(f"Silenced {channel_info} indefinitely.")
- await ctx.send(MSG_SILENCE_PERMANENT)
+ await self.send_message(MSG_SILENCE_PERMANENT, ctx.channel, channel, alert_target=True)
+
else:
log.info(f"Silenced {channel_info} for {duration} minute(s).")
- await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration))
+ formatted_message = MSG_SILENCE_SUCCESS.format(duration=duration)
+ await self.send_message(formatted_message, ctx.channel, channel, alert_target=True)
@commands.command(aliases=("unhush",))
- async def unsilence(self, ctx: Context) -> None:
+ async def unsilence(self, ctx: Context, *, channel: TextOrVoiceChannel = None) -> None:
"""
- Unsilence the current channel.
+ Unsilence the given channel if given, else the current one.
If the channel was silenced indefinitely, notifications for the channel will stop.
"""
await self._init_task
- log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.")
- await self._unsilence_wrapper(ctx.channel)
+ if channel is None:
+ channel = ctx.channel
+ log.debug(f"Unsilencing channel #{channel} from {ctx.author}'s command.")
+ await self._unsilence_wrapper(channel, ctx)
@lock_arg(LOCK_NAMESPACE, "channel", raise_error=True)
- async def _unsilence_wrapper(self, channel: TextChannel) -> None:
- """Unsilence `channel` and send a success/failure message."""
+ async def _unsilence_wrapper(self, channel: TextOrVoiceChannel, ctx: Optional[Context] = None) -> None:
+ """
+ Unsilence `channel` and send a success/failure message to ctx.channel.
+
+ If ctx is None or not passed, `channel` is used in its place.
+ If `channel` and ctx.channel are the same, only one message is sent.
+ """
+ msg_channel = channel
+ if ctx is not None:
+ msg_channel = ctx.channel
+
if not await self._unsilence(channel):
- overwrite = channel.overwrites_for(self._everyone_role)
- if overwrite.send_messages is False or overwrite.add_reactions is False:
- await channel.send(MSG_UNSILENCE_MANUAL)
+ if isinstance(channel, VoiceChannel):
+ overwrite = channel.overwrites_for(self._verified_voice_role)
+ manual = overwrite.speak is False
else:
- await channel.send(MSG_UNSILENCE_FAIL)
+ overwrite = channel.overwrites_for(self._everyone_role)
+ manual = overwrite.send_messages is False or overwrite.add_reactions is False
+
+ # Send fail message to muted channel or voice chat channel, and invocation channel
+ if manual:
+ await self.send_message(MSG_UNSILENCE_MANUAL, msg_channel, channel, alert_target=False)
+ else:
+ await self.send_message(MSG_UNSILENCE_FAIL, msg_channel, channel, alert_target=False)
+
else:
- await channel.send(MSG_UNSILENCE_SUCCESS)
+ await self.send_message(MSG_UNSILENCE_SUCCESS, msg_channel, channel, alert_target=True)
- async def _set_silence_overwrites(self, channel: TextChannel) -> bool:
+ async def _set_silence_overwrites(self, channel: TextOrVoiceChannel, *, kick: bool = False) -> bool:
"""Set silence permission overwrites for `channel` and return True if successful."""
- overwrite = channel.overwrites_for(self._everyone_role)
- prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions)
+ # Get the original channel overwrites
+ if isinstance(channel, TextChannel):
+ role = self._everyone_role
+ overwrite = channel.overwrites_for(role)
+ prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions)
+ else:
+ role = self._verified_voice_role
+ overwrite = channel.overwrites_for(role)
+ prev_overwrites = dict(speak=overwrite.speak)
+ if kick:
+ prev_overwrites.update(connect=overwrite.connect)
+
+ # Stop if channel was already silenced
if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()):
return False
- overwrite.update(send_messages=False, add_reactions=False)
- await channel.set_permissions(self._everyone_role, overwrite=overwrite)
+ # Set new permissions, store
+ overwrite.update(**dict.fromkeys(prev_overwrites, False))
+ await channel.set_permissions(role, overwrite=overwrite)
await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites))
return True
- async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None:
+ @staticmethod
+ async def _get_afk_channel(guild: Guild) -> VoiceChannel:
+ """Get a guild's AFK channel, or create one if it does not exist."""
+ afk_channel = guild.afk_channel
+
+ if afk_channel is None:
+ overwrites = {
+ guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False)
+ }
+ afk_channel = await guild.create_voice_channel("mute-temp", overwrites=overwrites)
+ log.info(f"Failed to get afk-channel, created temporary channel #{afk_channel} ({afk_channel.id})")
+
+ return afk_channel
+
+ @staticmethod
+ async def _kick_voice_members(channel: VoiceChannel) -> None:
+ """Remove all non-staff members from a voice channel."""
+ log.debug(f"Removing all non staff members from #{channel.name} ({channel.id}).")
+
+ for member in channel.members:
+ # Skip staff
+ if any(role.id in constants.MODERATION_ROLES for role in member.roles):
+ continue
+
+ try:
+ await member.move_to(None, reason="Kicking member from voice channel.")
+ log.debug(f"Kicked {member.name} from voice channel.")
+ except Exception as e:
+ log.debug(f"Failed to move {member.name}. Reason: {e}")
+ continue
+
+ log.debug("Removed all members.")
+
+ async def _force_voice_sync(self, channel: VoiceChannel) -> None:
+ """
+ Move all non-staff members from `channel` to a temporary channel and back to force toggle role mute.
+
+ Permission modification has to happen before this function.
+ """
+ # Obtain temporary channel
+ delete_channel = channel.guild.afk_channel is None
+ afk_channel = await self._get_afk_channel(channel.guild)
+
+ try:
+ # Move all members to temporary channel and back
+ for member in channel.members:
+ # Skip staff
+ if any(role.id in constants.MODERATION_ROLES for role in member.roles):
+ continue
+
+ try:
+ await member.move_to(afk_channel, reason="Muting VC member.")
+ log.debug(f"Moved {member.name} to afk channel.")
+
+ await member.move_to(channel, reason="Muting VC member.")
+ log.debug(f"Moved {member.name} to original voice channel.")
+ except Exception as e:
+ log.debug(f"Failed to move {member.name}. Reason: {e}")
+ continue
+
+ finally:
+ # Delete VC channel if it was created.
+ if delete_channel:
+ await afk_channel.delete(reason="Deleting temporary mute channel.")
+
+ async def _schedule_unsilence(self, ctx: Context, channel: TextOrVoiceChannel, duration: Optional[int]) -> None:
"""Schedule `ctx.channel` to be unsilenced if `duration` is not None."""
if duration is None:
- await self.unsilence_timestamps.set(ctx.channel.id, -1)
+ await self.unsilence_timestamps.set(channel.id, -1)
else:
- self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence))
+ self.scheduler.schedule_later(duration * 60, channel.id, ctx.invoke(self.unsilence, channel=channel))
unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration)
- await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp())
+ await self.unsilence_timestamps.set(channel.id, unsilence_time.timestamp())
- async def _unsilence(self, channel: TextChannel) -> bool:
+ async def _unsilence(self, channel: TextOrVoiceChannel) -> bool:
"""
Unsilence `channel`.
@@ -183,19 +346,34 @@ class Silence(commands.Cog):
Return `True` if channel permissions were changed, `False` otherwise.
"""
+ # Get stored overwrites, and return if channel is unsilenced
prev_overwrites = await self.previous_overwrites.get(channel.id)
if channel.id not in self.scheduler and prev_overwrites is None:
log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.")
return False
- overwrite = channel.overwrites_for(self._everyone_role)
+ # Select the role based on channel type, and get current overwrites
+ if isinstance(channel, TextChannel):
+ role = self._everyone_role
+ overwrite = channel.overwrites_for(role)
+ permissions = "`Send Messages` and `Add Reactions`"
+ else:
+ role = self._verified_voice_role
+ overwrite = channel.overwrites_for(role)
+ permissions = "`Speak` and `Connect`"
+
+ # Check if old overwrites were not stored
if prev_overwrites is None:
log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.")
- overwrite.update(send_messages=None, add_reactions=None)
+ overwrite.update(send_messages=None, add_reactions=None, speak=None, connect=None)
else:
overwrite.update(**json.loads(prev_overwrites))
- await channel.set_permissions(self._everyone_role, overwrite=overwrite)
+ # Update Permissions
+ await channel.set_permissions(role, overwrite=overwrite)
+ if isinstance(channel, VoiceChannel):
+ await self._force_voice_sync(channel)
+
log.info(f"Unsilenced channel #{channel} ({channel.id}).")
self.scheduler.cancel(channel.id)
@@ -203,11 +381,12 @@ class Silence(commands.Cog):
await self.previous_overwrites.delete(channel.id)
await self.unsilence_timestamps.delete(channel.id)
+ # Alert Admin team if old overwrites were not available
if prev_overwrites is None:
await self._mod_alerts_channel.send(
- f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing "
- f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` "
- f"overwrites for {self._everyone_role.mention} are at their desired values."
+ f"<@&{constants.Roles.admins}> Restored overwrites with default values after unsilencing "
+ f"{channel.mention}. Please check that the {permissions} "
+ f"overwrites for {role.mention} are at their desired values."
)
return True
@@ -247,7 +426,7 @@ class Silence(commands.Cog):
# This cannot be static (must have a __func__ attribute).
async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
- return await commands.has_any_role(*MODERATION_ROLES).predicate(ctx)
+ return await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx)
def setup(bot: Bot) -> None:
diff --git a/config-default.yml b/config-default.yml
index 46475f845..b9f6b40ac 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -210,16 +210,18 @@ guild:
# Voice Channels
admins_voice: &ADMINS_VOICE 500734494840717332
- code_help_voice_1: 751592231726481530
- code_help_voice_2: 764232549840846858
- general_voice: 751591688538947646
+ code_help_voice_0: 751592231726481530
+ code_help_voice_1: 764232549840846858
+ general_voice_0: 751591688538947646
+ general_voice_1: 799641437645701151
staff_voice: &STAFF_VOICE 412375055910043655
# Voice Chat
- code_help_chat_1: 755154969761677312
- code_help_chat_2: 766330079135268884
+ code_help_chat_0: 755154969761677312
+ code_help_chat_1: 766330079135268884
staff_voice_chat: 541638762007101470
- voice_chat: 412357430186344448
+ voice_chat_0: 412357430186344448
+ voice_chat_1: 799647045886541885
# Watch
big_brother_logs: &BB_LOGS 468507907357409333
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index fa5fc9e81..729b28412 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -1,15 +1,25 @@
import asyncio
import unittest
from datetime import datetime, timezone
+from typing import List, Tuple
from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from async_rediscache import RedisSession
from discord import PermissionOverwrite
-from bot.constants import Channels, Guild, Roles
+from bot.constants import Channels, Guild, MODERATION_ROLES, Roles
from bot.exts.moderation import silence
-from tests.helpers import MockBot, MockContext, MockTextChannel, autospec
+from tests.helpers import (
+ MockBot,
+ MockContext,
+ MockGuild,
+ MockMember,
+ MockRole,
+ MockTextChannel,
+ MockVoiceChannel,
+ autospec
+)
redis_session = None
redis_loop = asyncio.get_event_loop()
@@ -149,7 +159,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue(self.cog._init_task.cancelled())
@autospec("discord.ext.commands", "has_any_role")
- @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3))
+ @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3))
async def test_cog_check(self, role_check):
"""Role check was called with `MODERATION_ROLES`"""
ctx = MockContext()
@@ -159,6 +169,95 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
role_check.assert_called_once_with(*(1, 2, 3))
role_check.return_value.predicate.assert_awaited_once_with(ctx)
+ async def test_force_voice_sync(self):
+ """Tests the _force_voice_sync helper function."""
+ await self.cog._async_init()
+
+ # Create a regular member, and one member for each of the moderation roles
+ moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES]
+ members = [MockMember(), *moderation_members]
+
+ channel = MockVoiceChannel(members=members)
+
+ await self.cog._force_voice_sync(channel)
+ for member in members:
+ if member in moderation_members:
+ member.move_to.assert_not_called()
+ else:
+ self.assertEqual(member.move_to.call_count, 2)
+ calls = member.move_to.call_args_list
+
+ # Tests that the member was moved to the afk channel, and back.
+ self.assertEqual((channel.guild.afk_channel,), calls[0].args)
+ self.assertEqual((channel,), calls[1].args)
+
+ async def test_force_voice_sync_no_channel(self):
+ """Test to ensure _force_voice_sync can create its own voice channel if one is not available."""
+ await self.cog._async_init()
+
+ channel = MockVoiceChannel(guild=MockGuild(afk_channel=None))
+ new_channel = MockVoiceChannel(delete=AsyncMock())
+ channel.guild.create_voice_channel.return_value = new_channel
+
+ await self.cog._force_voice_sync(channel)
+
+ # Check channel creation
+ overwrites = {
+ channel.guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False)
+ }
+ channel.guild.create_voice_channel.assert_awaited_once_with("mute-temp", overwrites=overwrites)
+
+ # Check bot deleted channel
+ new_channel.delete.assert_awaited_once()
+
+ async def test_voice_kick(self):
+ """Test to ensure kick function can remove all members from a voice channel."""
+ await self.cog._async_init()
+
+ # Create a regular member, and one member for each of the moderation roles
+ moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES]
+ members = [MockMember(), *moderation_members]
+
+ channel = MockVoiceChannel(members=members)
+ await self.cog._kick_voice_members(channel)
+
+ for member in members:
+ if member in moderation_members:
+ member.move_to.assert_not_called()
+ else:
+ self.assertEqual((None,), member.move_to.call_args_list[0].args)
+
+ @staticmethod
+ def create_erroneous_members() -> Tuple[List[MockMember], List[MockMember]]:
+ """
+ Helper method to generate a list of members that error out on move_to call.
+
+ Returns the list of erroneous members,
+ as well as a list of regular and erroneous members combined, in that order.
+ """
+ erroneous_member = MockMember(move_to=AsyncMock(side_effect=Exception()))
+ members = [MockMember(), erroneous_member]
+
+ return erroneous_member, members
+
+ async def test_kick_move_to_error(self):
+ """Test to ensure move_to gets called on all members during kick, even if some fail."""
+ await self.cog._async_init()
+ _, members = self.create_erroneous_members()
+
+ await self.cog._kick_voice_members(MockVoiceChannel(members=members))
+ for member in members:
+ member.move_to.assert_awaited_once()
+
+ async def test_sync_move_to_error(self):
+ """Test to ensure move_to gets called on all members during sync, even if some fail."""
+ await self.cog._async_init()
+ failing_member, members = self.create_erroneous_members()
+
+ await self.cog._force_voice_sync(MockVoiceChannel(members=members))
+ for member in members:
+ self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2)
+
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
class RescheduleTests(unittest.IsolatedAsyncioTestCase):
@@ -235,6 +334,16 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase):
self.cog.notifier.add_channel.assert_not_called()
+def voice_sync_helper(function):
+ """Helper wrapper to test the sync and kick functions for voice channels."""
+ @autospec(silence.Silence, "_force_voice_sync", "_kick_voice_members", "_set_silence_overwrites")
+ async def inner(self, sync, kick, overwrites):
+ overwrites.return_value = True
+ await function(self, MockContext(), sync, kick)
+
+ return inner
+
+
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
class SilenceTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the silence command and its related helper methods."""
@@ -242,7 +351,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
@autospec(silence.Silence, "_reschedule", pass_mocks=False)
@autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
def setUp(self) -> None:
- self.bot = MockBot()
+ self.bot = MockBot(get_channel=lambda _: MockTextChannel())
self.cog = silence.Silence(self.bot)
self.cog._init_task = asyncio.Future()
self.cog._init_task.set_result(None)
@@ -252,56 +361,126 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
asyncio.run(self.cog._async_init()) # Populate instance attributes.
- self.channel = MockTextChannel()
- self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False)
- self.channel.overwrites_for.return_value = self.overwrite
+ self.text_channel = MockTextChannel()
+ self.text_overwrite = PermissionOverwrite(send_messages=True, add_reactions=False)
+ self.text_channel.overwrites_for.return_value = self.text_overwrite
+
+ self.voice_channel = MockVoiceChannel()
+ self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)
+ self.voice_channel.overwrites_for.return_value = self.voice_overwrite
async def test_sent_correct_message(self):
- """Appropriate failure/success message was sent by the command."""
+ """Appropriate failure/success message was sent by the command to the correct channel."""
+ # The following test tuples are made up of:
+ # duration, expected message, and the success of the _set_silence_overwrites function
test_cases = (
(0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,),
(None, silence.MSG_SILENCE_PERMANENT, True,),
(5, silence.MSG_SILENCE_FAIL, False,),
)
+
for duration, message, was_silenced in test_cases:
- ctx = MockContext()
with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced):
- with self.subTest(was_silenced=was_silenced, message=message, duration=duration):
- await self.cog.silence.callback(self.cog, ctx, duration)
- ctx.send.assert_called_once_with(message)
+ for target in [MockTextChannel(), MockVoiceChannel(), None]:
+ with self.subTest(was_silenced=was_silenced, target=target, message=message):
+ with mock.patch.object(self.cog, "send_message") as send_message:
+ ctx = MockContext()
+ await self.cog.silence.callback(self.cog, ctx, duration, target)
+ send_message.assert_called_once_with(
+ message,
+ ctx.channel,
+ target or ctx.channel,
+ alert_target=was_silenced
+ )
+
+ @voice_sync_helper
+ async def test_sync_called(self, ctx, sync, kick):
+ """Tests if silence command calls sync on a voice channel."""
+ channel = MockVoiceChannel()
+ await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False)
+
+ sync.assert_awaited_once_with(self.cog, channel)
+ kick.assert_not_called()
+
+ @voice_sync_helper
+ async def test_kick_called(self, ctx, sync, kick):
+ """Tests if silence command calls kick on a voice channel."""
+ channel = MockVoiceChannel()
+ await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True)
+
+ kick.assert_awaited_once_with(channel)
+ sync.assert_not_called()
+
+ @voice_sync_helper
+ async def test_sync_not_called(self, ctx, sync, kick):
+ """Tests that silence command does not call sync on a text channel."""
+ channel = MockTextChannel()
+ await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=False)
+
+ sync.assert_not_called()
+ kick.assert_not_called()
+
+ @voice_sync_helper
+ async def test_kick_not_called(self, ctx, sync, kick):
+ """Tests that silence command does not call kick on a text channel."""
+ channel = MockTextChannel()
+ await self.cog.silence.callback(self.cog, ctx, 10, channel, kick=True)
+
+ sync.assert_not_called()
+ kick.assert_not_called()
async def test_skipped_already_silenced(self):
"""Permissions were not set and `False` was returned for an already silenced channel."""
subtests = (
- (False, PermissionOverwrite(send_messages=False, add_reactions=False)),
- (True, PermissionOverwrite(send_messages=True, add_reactions=True)),
- (True, PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (False, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (True, MockTextChannel(), PermissionOverwrite(send_messages=True, add_reactions=True)),
+ (True, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (False, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)),
+ (True, MockVoiceChannel(), PermissionOverwrite(connect=True, speak=True)),
+ (True, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)),
)
- for contains, overwrite in subtests:
- with self.subTest(contains=contains, overwrite=overwrite):
+ for contains, channel, overwrite in subtests:
+ with self.subTest(contains=contains, is_text=isinstance(channel, MockTextChannel), overwrite=overwrite):
self.cog.scheduler.__contains__.return_value = contains
- channel = MockTextChannel()
channel.overwrites_for.return_value = overwrite
self.assertFalse(await self.cog._set_silence_overwrites(channel))
channel.set_permissions.assert_not_called()
- async def test_silenced_channel(self):
+ async def test_silenced_text_channel(self):
"""Channel had `send_message` and `add_reactions` permissions revoked for verified role."""
- self.assertTrue(await self.cog._set_silence_overwrites(self.channel))
- self.assertFalse(self.overwrite.send_messages)
- self.assertFalse(self.overwrite.add_reactions)
- self.channel.set_permissions.assert_awaited_once_with(
+ self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel))
+ self.assertFalse(self.text_overwrite.send_messages)
+ self.assertFalse(self.text_overwrite.add_reactions)
+ self.text_channel.set_permissions.assert_awaited_once_with(
self.cog._everyone_role,
- overwrite=self.overwrite
+ overwrite=self.text_overwrite
+ )
+
+ async def test_silenced_voice_channel_speak(self):
+ """Channel had `speak` permissions revoked for verified role."""
+ self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel))
+ self.assertFalse(self.voice_overwrite.speak)
+ self.voice_channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_voice_role,
+ overwrite=self.voice_overwrite
)
- async def test_preserved_other_overwrites(self):
- """Channel's other unrelated overwrites were not changed."""
- prev_overwrite_dict = dict(self.overwrite)
- await self.cog._set_silence_overwrites(self.channel)
- new_overwrite_dict = dict(self.overwrite)
+ async def test_silenced_voice_channel_full(self):
+ """Channel had `speak` and `connect` permissions revoked for verified role."""
+ self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel, kick=True))
+ self.assertFalse(self.voice_overwrite.speak or self.voice_overwrite.connect)
+ self.voice_channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_voice_role,
+ overwrite=self.voice_overwrite
+ )
+
+ async def test_preserved_other_overwrites_text(self):
+ """Channel's other unrelated overwrites were not changed for a text channel mute."""
+ prev_overwrite_dict = dict(self.text_overwrite)
+ await self.cog._set_silence_overwrites(self.text_channel)
+ new_overwrite_dict = dict(self.text_overwrite)
# Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.
del prev_overwrite_dict['send_messages']
@@ -311,6 +490,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+ async def test_preserved_other_overwrites_voice(self):
+ """Channel's other unrelated overwrites were not changed for a voice channel mute."""
+ prev_overwrite_dict = dict(self.voice_overwrite)
+ await self.cog._set_silence_overwrites(self.voice_channel)
+ new_overwrite_dict = dict(self.voice_overwrite)
+
+ # Remove 'connect' & 'speak' keys because they were changed by the method.
+ del prev_overwrite_dict['connect']
+ del prev_overwrite_dict['speak']
+ del new_overwrite_dict['connect']
+ del new_overwrite_dict['speak']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+
async def test_temp_not_added_to_notifier(self):
"""Channel was not added to notifier if a duration was set for the silence."""
with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
@@ -332,8 +525,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_cached_previous_overwrites(self):
"""Channel's previous overwrites were cached."""
overwrite_json = '{"send_messages": true, "add_reactions": false}'
- await self.cog._set_silence_overwrites(self.channel)
- self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json)
+ await self.cog._set_silence_overwrites(self.text_channel)
+ self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json)
@autospec(silence, "datetime")
async def test_cached_unsilence_time(self, datetime_mock):
@@ -343,7 +536,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
timestamp = now_timestamp + duration * 60
datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc)
- ctx = MockContext(channel=self.channel)
+ ctx = MockContext(channel=self.text_channel)
await self.cog.silence.callback(self.cog, ctx, duration)
self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp)
@@ -351,23 +544,23 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
async def test_cached_indefinite_time(self):
"""A value of -1 was cached for a permanent silence."""
- ctx = MockContext(channel=self.channel)
+ ctx = MockContext(channel=self.text_channel)
await self.cog.silence.callback(self.cog, ctx, None)
self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)
async def test_scheduled_task(self):
"""An unsilence task was scheduled."""
- ctx = MockContext(channel=self.channel, invoke=mock.MagicMock())
+ ctx = MockContext(channel=self.text_channel, invoke=mock.MagicMock())
await self.cog.silence.callback(self.cog, ctx, 5)
args = (300, ctx.channel.id, ctx.invoke.return_value)
self.cog.scheduler.schedule_later.assert_called_once_with(*args)
- ctx.invoke.assert_called_once_with(self.cog.unsilence)
+ ctx.invoke.assert_called_once_with(self.cog.unsilence, channel=ctx.channel)
async def test_permanent_not_scheduled(self):
"""A task was not scheduled for a permanent silence."""
- ctx = MockContext(channel=self.channel)
+ ctx = MockContext(channel=self.text_channel)
await self.cog.silence.callback(self.cog, ctx, None)
self.cog.scheduler.schedule_later.assert_not_called()
@@ -391,9 +584,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
self.cog.scheduler.__contains__.return_value = True
overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
- self.channel = MockTextChannel()
- self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False)
- self.channel.overwrites_for.return_value = self.overwrite
+ self.text_channel = MockTextChannel()
+ self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False)
+ self.text_channel.overwrites_for.return_value = self.text_overwrite
+
+ self.voice_channel = MockVoiceChannel()
+ self.voice_overwrite = PermissionOverwrite(connect=True, speak=True)
+ self.voice_channel.overwrites_for.return_value = self.voice_overwrite
async def test_sent_correct_message(self):
"""Appropriate failure/success message was sent by the command."""
@@ -401,88 +598,128 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
test_cases = (
(True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite),
(False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite),
- (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, self.text_overwrite),
(False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)),
(False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)),
)
+
for was_unsilenced, message, overwrite in test_cases:
ctx = MockContext()
- with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite):
+
+ for target in [None, MockTextChannel()]:
+ ctx.channel.overwrites_for.return_value = overwrite
+ if target:
+ target.overwrites_for.return_value = overwrite
+
with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced):
- ctx.channel.overwrites_for.return_value = overwrite
- await self.cog.unsilence.callback(self.cog, ctx)
- ctx.channel.send.assert_called_once_with(message)
+ with mock.patch.object(self.cog, "send_message") as send_message:
+ with self.subTest(was_unsilenced=was_unsilenced, overwrite=overwrite, target=target):
+ await self.cog.unsilence.callback(self.cog, ctx, channel=target)
+
+ call_args = (message, ctx.channel, target or ctx.channel)
+ send_message.assert_awaited_once_with(*call_args, alert_target=was_unsilenced)
async def test_skipped_already_unsilenced(self):
"""Permissions were not set and `False` was returned for an already unsilenced channel."""
self.cog.scheduler.__contains__.return_value = False
self.cog.previous_overwrites.get.return_value = None
- channel = MockTextChannel()
- self.assertFalse(await self.cog._unsilence(channel))
- channel.set_permissions.assert_not_called()
+ for channel in (MockVoiceChannel(), MockTextChannel()):
+ with self.subTest(channel=channel):
+ self.assertFalse(await self.cog._unsilence(channel))
+ channel.set_permissions.assert_not_called()
- async def test_restored_overwrites(self):
- """Channel's `send_message` and `add_reactions` overwrites were restored."""
- await self.cog._unsilence(self.channel)
- self.channel.set_permissions.assert_awaited_once_with(
+ async def test_restored_overwrites_text(self):
+ """Text channel's `send_message` and `add_reactions` overwrites were restored."""
+ await self.cog._unsilence(self.text_channel)
+ self.text_channel.set_permissions.assert_awaited_once_with(
self.cog._everyone_role,
- overwrite=self.overwrite,
+ overwrite=self.text_overwrite,
+ )
+
+ # Recall that these values are determined by the fixture.
+ self.assertTrue(self.text_overwrite.send_messages)
+ self.assertFalse(self.text_overwrite.add_reactions)
+
+ async def test_restored_overwrites_voice(self):
+ """Voice channel's `connect` and `speak` overwrites were restored."""
+ await self.cog._unsilence(self.voice_channel)
+ self.voice_channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_voice_role,
+ overwrite=self.voice_overwrite,
)
# Recall that these values are determined by the fixture.
- self.assertTrue(self.overwrite.send_messages)
- self.assertFalse(self.overwrite.add_reactions)
+ self.assertTrue(self.voice_overwrite.connect)
+ self.assertTrue(self.voice_overwrite.speak)
- async def test_cache_miss_used_default_overwrites(self):
- """Both overwrites were set to None due previous values not being found in the cache."""
+ async def test_cache_miss_used_default_overwrites_text(self):
+ """Text overwrites were set to None due previous values not being found in the cache."""
self.cog.previous_overwrites.get.return_value = None
- await self.cog._unsilence(self.channel)
- self.channel.set_permissions.assert_awaited_once_with(
+ await self.cog._unsilence(self.text_channel)
+ self.text_channel.set_permissions.assert_awaited_once_with(
self.cog._everyone_role,
- overwrite=self.overwrite,
+ overwrite=self.text_overwrite,
)
- self.assertIsNone(self.overwrite.send_messages)
- self.assertIsNone(self.overwrite.add_reactions)
+ self.assertIsNone(self.text_overwrite.send_messages)
+ self.assertIsNone(self.text_overwrite.add_reactions)
- async def test_cache_miss_sent_mod_alert(self):
- """A message was sent to the mod alerts channel."""
+ async def test_cache_miss_used_default_overwrites_voice(self):
+ """Voice overwrites were set to None due previous values not being found in the cache."""
self.cog.previous_overwrites.get.return_value = None
- await self.cog._unsilence(self.channel)
+ await self.cog._unsilence(self.voice_channel)
+ self.voice_channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_voice_role,
+ overwrite=self.voice_overwrite,
+ )
+
+ self.assertIsNone(self.voice_overwrite.connect)
+ self.assertIsNone(self.voice_overwrite.speak)
+
+ async def test_cache_miss_sent_mod_alert_text(self):
+ """A message was sent to the mod alerts channel upon muting a text channel."""
+ self.cog.previous_overwrites.get.return_value = None
+ await self.cog._unsilence(self.text_channel)
+ self.cog._mod_alerts_channel.send.assert_awaited_once()
+
+ async def test_cache_miss_sent_mod_alert_voice(self):
+ """A message was sent to the mod alerts channel upon muting a voice channel."""
+ self.cog.previous_overwrites.get.return_value = None
+ await self.cog._unsilence(MockVoiceChannel())
self.cog._mod_alerts_channel.send.assert_awaited_once()
async def test_removed_notifier(self):
"""Channel was removed from `notifier`."""
- await self.cog._unsilence(self.channel)
- self.cog.notifier.remove_channel.assert_called_once_with(self.channel)
+ await self.cog._unsilence(self.text_channel)
+ self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel)
async def test_deleted_cached_overwrite(self):
"""Channel was deleted from the overwrites cache."""
- await self.cog._unsilence(self.channel)
- self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id)
+ await self.cog._unsilence(self.text_channel)
+ self.cog.previous_overwrites.delete.assert_awaited_once_with(self.text_channel.id)
async def test_deleted_cached_time(self):
"""Channel was deleted from the timestamp cache."""
- await self.cog._unsilence(self.channel)
- self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id)
+ await self.cog._unsilence(self.text_channel)
+ self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.text_channel.id)
async def test_cancelled_task(self):
"""The scheduled unsilence task should be cancelled."""
- await self.cog._unsilence(self.channel)
- self.cog.scheduler.cancel.assert_called_once_with(self.channel.id)
+ await self.cog._unsilence(self.text_channel)
+ self.cog.scheduler.cancel.assert_called_once_with(self.text_channel.id)
- async def test_preserved_other_overwrites(self):
- """Channel's other unrelated overwrites were not changed, including cache misses."""
+ async def test_preserved_other_overwrites_text(self):
+ """Text channel's other unrelated overwrites were not changed, including cache misses."""
for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):
with self.subTest(overwrite_json=overwrite_json):
self.cog.previous_overwrites.get.return_value = overwrite_json
- prev_overwrite_dict = dict(self.overwrite)
- await self.cog._unsilence(self.channel)
- new_overwrite_dict = dict(self.overwrite)
+ prev_overwrite_dict = dict(self.text_overwrite)
+ await self.cog._unsilence(self.text_channel)
+ new_overwrite_dict = dict(self.text_overwrite)
# Remove these keys because they were modified by the unsilence.
del prev_overwrite_dict['send_messages']
@@ -491,3 +728,115 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
del new_overwrite_dict['add_reactions']
self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+
+ async def test_preserved_other_overwrites_voice(self):
+ """Voice channel's other unrelated overwrites were not changed, including cache misses."""
+ for overwrite_json in ('{"connect": true, "speak": true}', None):
+ with self.subTest(overwrite_json=overwrite_json):
+ self.cog.previous_overwrites.get.return_value = overwrite_json
+
+ prev_overwrite_dict = dict(self.voice_overwrite)
+ await self.cog._unsilence(self.voice_channel)
+ new_overwrite_dict = dict(self.voice_overwrite)
+
+ # Remove these keys because they were modified by the unsilence.
+ del prev_overwrite_dict['connect']
+ del prev_overwrite_dict['speak']
+ del new_overwrite_dict['connect']
+ del new_overwrite_dict['speak']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+
+ async def test_unsilence_role(self):
+ """Tests unsilence_wrapper applies permission to the correct role."""
+ test_cases = (
+ (MockTextChannel(), self.cog.bot.get_guild(Guild.id).default_role),
+ (MockVoiceChannel(), self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified))
+ )
+
+ for channel, role in test_cases:
+ with self.subTest(channel=channel, role=role):
+ await self.cog._unsilence_wrapper(channel, MockContext())
+ channel.overwrites_for.assert_called_with(role)
+
+
+class SendMessageTests(unittest.IsolatedAsyncioTestCase):
+ """Unittests for the send message helper function."""
+
+ def setUp(self) -> None:
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+
+ self.text_channels = [MockTextChannel() for _ in range(2)]
+ self.bot.get_channel.return_value = self.text_channels[1]
+
+ self.voice_channel = MockVoiceChannel()
+
+ async def test_send_to_channel(self):
+ """Tests a basic case for the send function."""
+ message = "Test basic message."
+ await self.cog.send_message(message, *self.text_channels, alert_target=False)
+
+ self.text_channels[0].send.assert_awaited_once_with(message)
+ self.text_channels[1].send.assert_not_called()
+
+ async def test_send_to_multiple_channels(self):
+ """Tests sending messages to two channels."""
+ message = "Test basic message."
+ await self.cog.send_message(message, *self.text_channels, alert_target=True)
+
+ self.text_channels[0].send.assert_awaited_once_with(message)
+ self.text_channels[1].send.assert_awaited_once_with(message)
+
+ async def test_duration_replacement(self):
+ """Tests that the channel name was set correctly for one target channel."""
+ message = "Current. The following should be replaced: current channel."
+ await self.cog.send_message(message, *self.text_channels, alert_target=False)
+
+ updated_message = message.replace("current channel", self.text_channels[0].mention)
+ self.text_channels[0].send.assert_awaited_once_with(updated_message)
+ self.text_channels[1].send.assert_not_called()
+
+ async def test_name_replacement_multiple_channels(self):
+ """Tests that the channel name was set correctly for two channels."""
+ message = "Current. The following should be replaced: current channel."
+ await self.cog.send_message(message, *self.text_channels, alert_target=True)
+
+ updated_message = message.replace("current channel", self.text_channels[0].mention)
+ self.text_channels[0].send.assert_awaited_once_with(updated_message)
+ self.text_channels[1].send.assert_awaited_once_with(message)
+
+ async def test_silence_voice(self):
+ """Tests that the correct message was sent when a voice channel is muted without alerting."""
+ message = "This should show up just here."
+ await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=False)
+ self.text_channels[0].send.assert_awaited_once_with(message)
+ self.text_channels[1].send.assert_not_called()
+
+ async def test_silence_voice_alert(self):
+ """Tests that the correct message was sent when a voice channel is muted with alerts."""
+ with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels:
+ mock_voice_channels.get.return_value = self.text_channels[1].id
+
+ message = "This should show up as current channel."
+ await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=True)
+
+ updated_message = message.replace("current channel", self.voice_channel.mention)
+ self.text_channels[0].send.assert_awaited_once_with(updated_message)
+ self.text_channels[1].send.assert_awaited_once_with(updated_message)
+
+ mock_voice_channels.get.assert_called_once_with(self.voice_channel.id)
+
+ async def test_silence_voice_sibling_channel(self):
+ """Tests silencing a voice channel from the related text channel."""
+ with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels:
+ mock_voice_channels.get.return_value = self.text_channels[1].id
+
+ message = "This should show up as current channel."
+ await self.cog.send_message(message, self.text_channels[1], self.voice_channel, alert_target=True)
+
+ updated_message = message.replace("current channel", self.voice_channel.mention)
+ self.text_channels[1].send.assert_awaited_once_with(updated_message)
+
+ mock_voice_channels.get.assert_called_once_with(self.voice_channel.id)
+ self.bot.get_channel.assert_called_once_with(self.text_channels[1].id)
diff --git a/tests/helpers.py b/tests/helpers.py
index e3dc5fe5b..86cc635f8 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
from tests._autospec import autospec # noqa: F401 other modules import it via this module
-
for logger in logging.Logger.manager.loggerDict.values():
# Set all loggers to CRITICAL by default to prevent screen clutter during testing
@@ -320,7 +319,10 @@ channel_data = {
}
state = unittest.mock.MagicMock()
guild = unittest.mock.MagicMock()
-channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data)
+text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data)
+
+channel_data["type"] = "VoiceChannel"
+voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data)
class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
@@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
- spec_set = channel_instance
+ spec_set = text_channel_instance
+
+ def __init__(self, **kwargs) -> None:
+ default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
+
+ if 'mention' not in kwargs:
+ self.mention = f"#{self.name}"
+
+
+class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
+ """
+ A MagicMock subclass to mock VoiceChannel objects.
+
+ Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ spec_set = voice_channel_instance
def __init__(self, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}