diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 84 | ||||
-rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 600 | ||||
-rw-r--r-- | tests/bot/test_converters.py | 2 | ||||
-rw-r--r-- | tests/helpers.py | 25 |
4 files changed, 613 insertions, 98 deletions
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index bd4fb5942..37e8108fc 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.backend.error_handler import ErrorHandler, setup from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockRole +from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -226,8 +226,8 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError()) self.assertFalse(await self.cog.try_silence(self.ctx)) - async def test_try_silence_silencing(self): - """Should run silence command with correct arguments.""" + async def test_try_silence_silence_duration(self): + """Should run silence command with correct duration argument.""" self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh") @@ -238,21 +238,85 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(await self.cog.try_silence(self.ctx)) self.ctx.invoke.assert_awaited_once_with( self.bot.get_command.return_value, - duration=min(case.count("h")*2, 15) + duration_or_channel=None, + duration=min(case.count("h")*2, 15), + kick=False ) + async def test_try_silence_silence_arguments(self): + """Should run silence with the correct channel, duration, and kick arguments.""" + self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) + + test_cases = ( + (MockTextChannel(), None), # None represents the case when no argument is passed + (MockTextChannel(), False), + (MockTextChannel(), True) + ) + + for channel, kick in test_cases: + with self.subTest(kick=kick, channel=channel): + self.ctx.reset_mock() + self.ctx.invoked_with = "shh" + + self.ctx.message.content = f"!shh {channel.name} {kick if kick is not None else ''}" + self.ctx.guild.text_channels = [channel] + + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with( + self.bot.get_command.return_value, + duration_or_channel=channel, + duration=4, + kick=(kick if kick is not None else False) + ) + + async def test_try_silence_silence_message(self): + """If the words after the command could not be converted to a channel, None should be passed as channel.""" + self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) + self.ctx.invoked_with = "shh" + self.ctx.message.content = "!shh not_a_channel true" + + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with( + self.bot.get_command.return_value, + duration_or_channel=None, + duration=4, + kick=False + ) + async def test_try_silence_unsilence(self): - """Should call unsilence command.""" + """Should call unsilence command with correct duration and channel arguments.""" self.silence.silence.can_run = AsyncMock(return_value=True) - test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh") + test_cases = ( + ("unshh", None), + ("unshhhhh", None), + ("unshhhhhhhhh", None), + ("unshh", MockTextChannel()) + ) - for case in test_cases: - with self.subTest(message=case): + for invoke, channel in test_cases: + with self.subTest(message=invoke, channel=channel): self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) self.ctx.reset_mock() - self.ctx.invoked_with = case + + self.ctx.invoked_with = invoke + self.ctx.message.content = f"!{invoke}" + if channel is not None: + self.ctx.message.content += f" {channel.name}" + self.ctx.guild.text_channels = [channel] + self.assertTrue(await self.cog.try_silence(self.ctx)) - self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence) + self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=channel) + + async def test_try_silence_unsilence_message(self): + """If the words after the command could not be converted to a channel, None should be passed as channel.""" + self.silence.silence.can_run = AsyncMock(return_value=True) + self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) + + self.ctx.invoked_with = "unshh" + self.ctx.message.content = "!unshh not_a_channel" + + self.assertTrue(await self.cog.try_silence(self.ctx)) + self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=None) async def test_try_silence_no_match(self): """Should return `False` when message don't match.""" diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index fa5fc9e81..59a5893ef 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,15 +1,26 @@ import asyncio +import itertools import unittest from datetime import datetime, timezone +from typing import List, Tuple from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from async_rediscache import RedisSession from discord import PermissionOverwrite -from bot.constants import Channels, Guild, Roles +from bot.constants import Channels, Guild, MODERATION_ROLES, Roles from bot.exts.moderation import silence -from tests.helpers import MockBot, MockContext, MockTextChannel, autospec +from tests.helpers import ( + MockBot, + MockContext, + MockGuild, + MockMember, + MockRole, + MockTextChannel, + MockVoiceChannel, + autospec +) redis_session = None redis_loop = asyncio.get_event_loop() @@ -149,7 +160,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): self.assertTrue(self.cog._init_task.cancelled()) @autospec("discord.ext.commands", "has_any_role") - @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) + @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3)) async def test_cog_check(self, role_check): """Role check was called with `MODERATION_ROLES`""" ctx = MockContext() @@ -159,6 +170,170 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase): role_check.assert_called_once_with(*(1, 2, 3)) role_check.return_value.predicate.assert_awaited_once_with(ctx) + async def test_force_voice_sync(self): + """Tests the _force_voice_sync helper function.""" + await self.cog._async_init() + + # Create a regular member, and one member for each of the moderation roles + moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] + members = [MockMember(), *moderation_members] + + channel = MockVoiceChannel(members=members) + + await self.cog._force_voice_sync(channel) + for member in members: + if member in moderation_members: + member.move_to.assert_not_called() + else: + self.assertEqual(member.move_to.call_count, 2) + calls = member.move_to.call_args_list + + # Tests that the member was moved to the afk channel, and back. + self.assertEqual((channel.guild.afk_channel,), calls[0].args) + self.assertEqual((channel,), calls[1].args) + + async def test_force_voice_sync_no_channel(self): + """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" + await self.cog._async_init() + + channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) + new_channel = MockVoiceChannel(delete=AsyncMock()) + channel.guild.create_voice_channel.return_value = new_channel + + await self.cog._force_voice_sync(channel) + + # Check channel creation + overwrites = { + channel.guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) + } + channel.guild.create_voice_channel.assert_awaited_once_with("mute-temp", overwrites=overwrites) + + # Check bot deleted channel + new_channel.delete.assert_awaited_once() + + async def test_voice_kick(self): + """Test to ensure kick function can remove all members from a voice channel.""" + await self.cog._async_init() + + # Create a regular member, and one member for each of the moderation roles + moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] + members = [MockMember(), *moderation_members] + + channel = MockVoiceChannel(members=members) + await self.cog._kick_voice_members(channel) + + for member in members: + if member in moderation_members: + member.move_to.assert_not_called() + else: + self.assertEqual((None,), member.move_to.call_args_list[0].args) + + @staticmethod + def create_erroneous_members() -> Tuple[List[MockMember], List[MockMember]]: + """ + Helper method to generate a list of members that error out on move_to call. + + Returns the list of erroneous members, + as well as a list of regular and erroneous members combined, in that order. + """ + erroneous_member = MockMember(move_to=AsyncMock(side_effect=Exception())) + members = [MockMember(), erroneous_member] + + return erroneous_member, members + + async def test_kick_move_to_error(self): + """Test to ensure move_to gets called on all members during kick, even if some fail.""" + await self.cog._async_init() + _, members = self.create_erroneous_members() + + await self.cog._kick_voice_members(MockVoiceChannel(members=members)) + for member in members: + member.move_to.assert_awaited_once() + + async def test_sync_move_to_error(self): + """Test to ensure move_to gets called on all members during sync, even if some fail.""" + await self.cog._async_init() + failing_member, members = self.create_erroneous_members() + + await self.cog._force_voice_sync(MockVoiceChannel(members=members)) + for member in members: + self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) + + +class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): + """Tests for the silence argument parser utility function.""" + + def setUp(self): + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + self.cog._init_task = asyncio.Future() + self.cog._init_task.set_result(None) + + @autospec(silence.Silence, "send_message", pass_mocks=False) + @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) + @autospec(silence.Silence, "parse_silence_args") + async def test_command(self, parser_mock): + """Test that the command passes in the correct arguments for different calls.""" + test_cases = ( + (), + (15, ), + (MockTextChannel(),), + (MockTextChannel(), 15), + ) + + ctx = MockContext() + parser_mock.return_value = (ctx.channel, 10) + + for case in test_cases: + with self.subTest("Test command converters", args=case): + await self.cog.silence.callback(self.cog, ctx, *case) + + try: + first_arg = case[0] + except IndexError: + # Default value when the first argument is not passed + first_arg = None + + try: + second_arg = case[1] + except IndexError: + # Default value when the second argument is not passed + second_arg = 10 + + parser_mock.assert_called_with(ctx, first_arg, second_arg) + + async def test_no_arguments(self): + """Test the parser when no arguments are passed to the command.""" + ctx = MockContext() + channel, duration = self.cog.parse_silence_args(ctx, None, 10) + + self.assertEqual(ctx.channel, channel) + self.assertEqual(10, duration) + + async def test_channel_only(self): + """Test the parser when just the channel argument is passed.""" + expected_channel = MockTextChannel() + actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 10) + + self.assertEqual(expected_channel, actual_channel) + self.assertEqual(10, duration) + + async def test_duration_only(self): + """Test the parser when just the duration argument is passed.""" + ctx = MockContext() + channel, duration = self.cog.parse_silence_args(ctx, 15, 10) + + self.assertEqual(ctx.channel, channel) + self.assertEqual(15, duration) + + async def test_all_args(self): + """Test the parser when both channel and duration are passed.""" + expected_channel = MockTextChannel() + actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 15) + + self.assertEqual(expected_channel, actual_channel) + self.assertEqual(15, duration) + @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) class RescheduleTests(unittest.IsolatedAsyncioTestCase): @@ -235,6 +410,16 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase): self.cog.notifier.add_channel.assert_not_called() +def voice_sync_helper(function): + """Helper wrapper to test the sync and kick functions for voice channels.""" + @autospec(silence.Silence, "_force_voice_sync", "_kick_voice_members", "_set_silence_overwrites") + async def inner(self, sync, kick, overwrites): + overwrites.return_value = True + await function(self, MockContext(), sync, kick) + + return inner + + @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) class SilenceTests(unittest.IsolatedAsyncioTestCase): """Tests for the silence command and its related helper methods.""" @@ -242,7 +427,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): @autospec(silence.Silence, "_reschedule", pass_mocks=False) @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) def setUp(self) -> None: - self.bot = MockBot() + self.bot = MockBot(get_channel=lambda _: MockTextChannel()) self.cog = silence.Silence(self.bot) self.cog._init_task = asyncio.Future() self.cog._init_task.set_result(None) @@ -252,56 +437,127 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): asyncio.run(self.cog._async_init()) # Populate instance attributes. - self.channel = MockTextChannel() - self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) - self.channel.overwrites_for.return_value = self.overwrite + self.text_channel = MockTextChannel() + self.text_overwrite = PermissionOverwrite(send_messages=True, add_reactions=False) + self.text_channel.overwrites_for.return_value = self.text_overwrite + + self.voice_channel = MockVoiceChannel() + self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) + self.voice_channel.overwrites_for.return_value = self.voice_overwrite async def test_sent_correct_message(self): - """Appropriate failure/success message was sent by the command.""" + """Appropriate failure/success message was sent by the command to the correct channel.""" + # The following test tuples are made up of: + # duration, expected message, and the success of the _set_silence_overwrites function test_cases = ( (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), (None, silence.MSG_SILENCE_PERMANENT, True,), (5, silence.MSG_SILENCE_FAIL, False,), ) - for duration, message, was_silenced in test_cases: - ctx = MockContext() + + targets = (MockTextChannel(), MockVoiceChannel(), None) + + for (duration, message, was_silenced), target in itertools.product(test_cases, targets): with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): - with self.subTest(was_silenced=was_silenced, message=message, duration=duration): - await self.cog.silence.callback(self.cog, ctx, duration) - ctx.send.assert_called_once_with(message) + with self.subTest(was_silenced=was_silenced, target=target, message=message): + with mock.patch.object(self.cog, "send_message") as send_message: + ctx = MockContext() + await self.cog.silence.callback(self.cog, ctx, target, duration) + send_message.assert_called_once_with( + message, + ctx.channel, + target or ctx.channel, + alert_target=was_silenced + ) + + @voice_sync_helper + async def test_sync_called(self, ctx, sync, kick): + """Tests if silence command calls sync on a voice channel.""" + channel = MockVoiceChannel() + await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=False) + + sync.assert_awaited_once_with(self.cog, channel) + kick.assert_not_called() + + @voice_sync_helper + async def test_kick_called(self, ctx, sync, kick): + """Tests if silence command calls kick on a voice channel.""" + channel = MockVoiceChannel() + await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=True) + + kick.assert_awaited_once_with(channel) + sync.assert_not_called() + + @voice_sync_helper + async def test_sync_not_called(self, ctx, sync, kick): + """Tests that silence command does not call sync on a text channel.""" + channel = MockTextChannel() + await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=False) + + sync.assert_not_called() + kick.assert_not_called() + + @voice_sync_helper + async def test_kick_not_called(self, ctx, sync, kick): + """Tests that silence command does not call kick on a text channel.""" + channel = MockTextChannel() + await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=True) + + sync.assert_not_called() + kick.assert_not_called() async def test_skipped_already_silenced(self): """Permissions were not set and `False` was returned for an already silenced channel.""" subtests = ( - (False, PermissionOverwrite(send_messages=False, add_reactions=False)), - (True, PermissionOverwrite(send_messages=True, add_reactions=True)), - (True, PermissionOverwrite(send_messages=False, add_reactions=False)), + (False, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)), + (True, MockTextChannel(), PermissionOverwrite(send_messages=True, add_reactions=True)), + (True, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)), + (False, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)), + (True, MockVoiceChannel(), PermissionOverwrite(connect=True, speak=True)), + (True, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)), ) - for contains, overwrite in subtests: - with self.subTest(contains=contains, overwrite=overwrite): + for contains, channel, overwrite in subtests: + with self.subTest(contains=contains, is_text=isinstance(channel, MockTextChannel), overwrite=overwrite): self.cog.scheduler.__contains__.return_value = contains - channel = MockTextChannel() channel.overwrites_for.return_value = overwrite self.assertFalse(await self.cog._set_silence_overwrites(channel)) channel.set_permissions.assert_not_called() - async def test_silenced_channel(self): + async def test_silenced_text_channel(self): """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" - self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) - self.assertFalse(self.overwrite.send_messages) - self.assertFalse(self.overwrite.add_reactions) - self.channel.set_permissions.assert_awaited_once_with( + self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel)) + self.assertFalse(self.text_overwrite.send_messages) + self.assertFalse(self.text_overwrite.add_reactions) + self.text_channel.set_permissions.assert_awaited_once_with( self.cog._everyone_role, - overwrite=self.overwrite + overwrite=self.text_overwrite ) - async def test_preserved_other_overwrites(self): - """Channel's other unrelated overwrites were not changed.""" - prev_overwrite_dict = dict(self.overwrite) - await self.cog._set_silence_overwrites(self.channel) - new_overwrite_dict = dict(self.overwrite) + async def test_silenced_voice_channel_speak(self): + """Channel had `speak` permissions revoked for verified role.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel)) + self.assertFalse(self.voice_overwrite.speak) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite + ) + + async def test_silenced_voice_channel_full(self): + """Channel had `speak` and `connect` permissions revoked for verified role.""" + self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel, kick=True)) + self.assertFalse(self.voice_overwrite.speak or self.voice_overwrite.connect) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite + ) + + async def test_preserved_other_overwrites_text(self): + """Channel's other unrelated overwrites were not changed for a text channel mute.""" + prev_overwrite_dict = dict(self.text_overwrite) + await self.cog._set_silence_overwrites(self.text_channel) + new_overwrite_dict = dict(self.text_overwrite) # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. del prev_overwrite_dict['send_messages'] @@ -311,6 +567,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + async def test_preserved_other_overwrites_voice(self): + """Channel's other unrelated overwrites were not changed for a voice channel mute.""" + prev_overwrite_dict = dict(self.voice_overwrite) + await self.cog._set_silence_overwrites(self.voice_channel) + new_overwrite_dict = dict(self.voice_overwrite) + + # Remove 'connect' & 'speak' keys because they were changed by the method. + del prev_overwrite_dict['connect'] + del prev_overwrite_dict['speak'] + del new_overwrite_dict['connect'] + del new_overwrite_dict['speak'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + async def test_temp_not_added_to_notifier(self): """Channel was not added to notifier if a duration was set for the silence.""" with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): @@ -320,7 +590,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_indefinite_added_to_notifier(self): """Channel was added to notifier if a duration was not set for the silence.""" with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): - await self.cog.silence.callback(self.cog, MockContext(), None) + await self.cog.silence.callback(self.cog, MockContext(), None, None) self.cog.notifier.add_channel.assert_called_once() async def test_silenced_not_added_to_notifier(self): @@ -332,8 +602,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_cached_previous_overwrites(self): """Channel's previous overwrites were cached.""" overwrite_json = '{"send_messages": true, "add_reactions": false}' - await self.cog._set_silence_overwrites(self.channel) - self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + await self.cog._set_silence_overwrites(self.text_channel) + self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json) @autospec(silence, "datetime") async def test_cached_unsilence_time(self, datetime_mock): @@ -343,7 +613,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): timestamp = now_timestamp + duration * 60 datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) - ctx = MockContext(channel=self.channel) + ctx = MockContext(channel=self.text_channel) await self.cog.silence.callback(self.cog, ctx, duration) self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) @@ -351,26 +621,33 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase): async def test_cached_indefinite_time(self): """A value of -1 was cached for a permanent silence.""" - ctx = MockContext(channel=self.channel) - await self.cog.silence.callback(self.cog, ctx, None) + ctx = MockContext(channel=self.text_channel) + await self.cog.silence.callback(self.cog, ctx, None, None) self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) async def test_scheduled_task(self): """An unsilence task was scheduled.""" - ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + ctx = MockContext(channel=self.text_channel, invoke=mock.MagicMock()) await self.cog.silence.callback(self.cog, ctx, 5) args = (300, ctx.channel.id, ctx.invoke.return_value) self.cog.scheduler.schedule_later.assert_called_once_with(*args) - ctx.invoke.assert_called_once_with(self.cog.unsilence) + ctx.invoke.assert_called_once_with(self.cog.unsilence, channel=ctx.channel) async def test_permanent_not_scheduled(self): """A task was not scheduled for a permanent silence.""" - ctx = MockContext(channel=self.channel) - await self.cog.silence.callback(self.cog, ctx, None) + ctx = MockContext(channel=self.text_channel) + await self.cog.silence.callback(self.cog, ctx, None, None) self.cog.scheduler.schedule_later.assert_not_called() + async def test_indefinite_silence(self): + """Test silencing a channel forever.""" + with mock.patch.object(self.cog, "_schedule_unsilence") as unsilence: + ctx = MockContext(channel=self.text_channel) + await self.cog.silence.callback(self.cog, ctx, -1) + unsilence.assert_awaited_once_with(ctx, ctx.channel, None) + @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) class UnsilenceTests(unittest.IsolatedAsyncioTestCase): @@ -391,9 +668,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): self.cog.scheduler.__contains__.return_value = True overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' - self.channel = MockTextChannel() - self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) - self.channel.overwrites_for.return_value = self.overwrite + self.text_channel = MockTextChannel() + self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) + self.text_channel.overwrites_for.return_value = self.text_overwrite + + self.voice_channel = MockVoiceChannel() + self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) + self.voice_channel.overwrites_for.return_value = self.voice_overwrite async def test_sent_correct_message(self): """Appropriate failure/success message was sent by the command.""" @@ -401,88 +682,128 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): test_cases = ( (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), - (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), + (False, silence.MSG_UNSILENCE_MANUAL, self.text_overwrite), (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), ) - for was_unsilenced, message, overwrite in test_cases: + + targets = (None, MockTextChannel()) + + for (was_unsilenced, message, overwrite), target in itertools.product(test_cases, targets): ctx = MockContext() - with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): - with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): - ctx.channel.overwrites_for.return_value = overwrite - await self.cog.unsilence.callback(self.cog, ctx) - ctx.channel.send.assert_called_once_with(message) + ctx.channel.overwrites_for.return_value = overwrite + if target: + target.overwrites_for.return_value = overwrite + + with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): + with mock.patch.object(self.cog, "send_message") as send_message: + with self.subTest(was_unsilenced=was_unsilenced, overwrite=overwrite, target=target): + await self.cog.unsilence.callback(self.cog, ctx, channel=target) + + call_args = (message, ctx.channel, target or ctx.channel) + send_message.assert_awaited_once_with(*call_args, alert_target=was_unsilenced) async def test_skipped_already_unsilenced(self): """Permissions were not set and `False` was returned for an already unsilenced channel.""" self.cog.scheduler.__contains__.return_value = False self.cog.previous_overwrites.get.return_value = None - channel = MockTextChannel() - self.assertFalse(await self.cog._unsilence(channel)) - channel.set_permissions.assert_not_called() + for channel in (MockVoiceChannel(), MockTextChannel()): + with self.subTest(channel=channel): + self.assertFalse(await self.cog._unsilence(channel)) + channel.set_permissions.assert_not_called() - async def test_restored_overwrites(self): - """Channel's `send_message` and `add_reactions` overwrites were restored.""" - await self.cog._unsilence(self.channel) - self.channel.set_permissions.assert_awaited_once_with( + async def test_restored_overwrites_text(self): + """Text channel's `send_message` and `add_reactions` overwrites were restored.""" + await self.cog._unsilence(self.text_channel) + self.text_channel.set_permissions.assert_awaited_once_with( self.cog._everyone_role, - overwrite=self.overwrite, + overwrite=self.text_overwrite, + ) + + # Recall that these values are determined by the fixture. + self.assertTrue(self.text_overwrite.send_messages) + self.assertFalse(self.text_overwrite.add_reactions) + + async def test_restored_overwrites_voice(self): + """Voice channel's `connect` and `speak` overwrites were restored.""" + await self.cog._unsilence(self.voice_channel) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite, ) # Recall that these values are determined by the fixture. - self.assertTrue(self.overwrite.send_messages) - self.assertFalse(self.overwrite.add_reactions) + self.assertTrue(self.voice_overwrite.connect) + self.assertTrue(self.voice_overwrite.speak) - async def test_cache_miss_used_default_overwrites(self): - """Both overwrites were set to None due previous values not being found in the cache.""" + async def test_cache_miss_used_default_overwrites_text(self): + """Text overwrites were set to None due previous values not being found in the cache.""" self.cog.previous_overwrites.get.return_value = None - await self.cog._unsilence(self.channel) - self.channel.set_permissions.assert_awaited_once_with( + await self.cog._unsilence(self.text_channel) + self.text_channel.set_permissions.assert_awaited_once_with( self.cog._everyone_role, - overwrite=self.overwrite, + overwrite=self.text_overwrite, + ) + + self.assertIsNone(self.text_overwrite.send_messages) + self.assertIsNone(self.text_overwrite.add_reactions) + + async def test_cache_miss_used_default_overwrites_voice(self): + """Voice overwrites were set to None due previous values not being found in the cache.""" + self.cog.previous_overwrites.get.return_value = None + + await self.cog._unsilence(self.voice_channel) + self.voice_channel.set_permissions.assert_awaited_once_with( + self.cog._verified_voice_role, + overwrite=self.voice_overwrite, ) - self.assertIsNone(self.overwrite.send_messages) - self.assertIsNone(self.overwrite.add_reactions) + self.assertIsNone(self.voice_overwrite.connect) + self.assertIsNone(self.voice_overwrite.speak) - async def test_cache_miss_sent_mod_alert(self): - """A message was sent to the mod alerts channel.""" + async def test_cache_miss_sent_mod_alert_text(self): + """A message was sent to the mod alerts channel upon muting a text channel.""" self.cog.previous_overwrites.get.return_value = None + await self.cog._unsilence(self.text_channel) + self.cog._mod_alerts_channel.send.assert_awaited_once() - await self.cog._unsilence(self.channel) + async def test_cache_miss_sent_mod_alert_voice(self): + """A message was sent to the mod alerts channel upon muting a voice channel.""" + self.cog.previous_overwrites.get.return_value = None + await self.cog._unsilence(MockVoiceChannel()) self.cog._mod_alerts_channel.send.assert_awaited_once() async def test_removed_notifier(self): """Channel was removed from `notifier`.""" - await self.cog._unsilence(self.channel) - self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + await self.cog._unsilence(self.text_channel) + self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel) async def test_deleted_cached_overwrite(self): """Channel was deleted from the overwrites cache.""" - await self.cog._unsilence(self.channel) - self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + await self.cog._unsilence(self.text_channel) + self.cog.previous_overwrites.delete.assert_awaited_once_with(self.text_channel.id) async def test_deleted_cached_time(self): """Channel was deleted from the timestamp cache.""" - await self.cog._unsilence(self.channel) - self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + await self.cog._unsilence(self.text_channel) + self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.text_channel.id) async def test_cancelled_task(self): """The scheduled unsilence task should be cancelled.""" - await self.cog._unsilence(self.channel) - self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + await self.cog._unsilence(self.text_channel) + self.cog.scheduler.cancel.assert_called_once_with(self.text_channel.id) - async def test_preserved_other_overwrites(self): - """Channel's other unrelated overwrites were not changed, including cache misses.""" + async def test_preserved_other_overwrites_text(self): + """Text channel's other unrelated overwrites were not changed, including cache misses.""" for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): with self.subTest(overwrite_json=overwrite_json): self.cog.previous_overwrites.get.return_value = overwrite_json - prev_overwrite_dict = dict(self.overwrite) - await self.cog._unsilence(self.channel) - new_overwrite_dict = dict(self.overwrite) + prev_overwrite_dict = dict(self.text_overwrite) + await self.cog._unsilence(self.text_channel) + new_overwrite_dict = dict(self.text_overwrite) # Remove these keys because they were modified by the unsilence. del prev_overwrite_dict['send_messages'] @@ -491,3 +812,114 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase): del new_overwrite_dict['add_reactions'] self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + async def test_preserved_other_overwrites_voice(self): + """Voice channel's other unrelated overwrites were not changed, including cache misses.""" + for overwrite_json in ('{"connect": true, "speak": true}', None): + with self.subTest(overwrite_json=overwrite_json): + self.cog.previous_overwrites.get.return_value = overwrite_json + + prev_overwrite_dict = dict(self.voice_overwrite) + await self.cog._unsilence(self.voice_channel) + new_overwrite_dict = dict(self.voice_overwrite) + + # Remove these keys because they were modified by the unsilence. + del prev_overwrite_dict['connect'] + del prev_overwrite_dict['speak'] + del new_overwrite_dict['connect'] + del new_overwrite_dict['speak'] + + self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + + async def test_unsilence_role(self): + """Tests unsilence_wrapper applies permission to the correct role.""" + test_cases = ( + (MockTextChannel(), self.cog.bot.get_guild(Guild.id).default_role), + (MockVoiceChannel(), self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified)) + ) + + for channel, role in test_cases: + with self.subTest(channel=channel, role=role): + await self.cog._unsilence_wrapper(channel, MockContext()) + channel.overwrites_for.assert_called_with(role) + + +class SendMessageTests(unittest.IsolatedAsyncioTestCase): + """Unittests for the send message helper function.""" + + def setUp(self) -> None: + self.bot = MockBot() + self.cog = silence.Silence(self.bot) + + self.text_channels = [MockTextChannel() for _ in range(2)] + self.bot.get_channel.return_value = self.text_channels[1] + + self.voice_channel = MockVoiceChannel() + + async def test_send_to_channel(self): + """Tests a basic case for the send function.""" + message = "Test basic message." + await self.cog.send_message(message, *self.text_channels, alert_target=False) + + self.text_channels[0].send.assert_awaited_once_with(message) + self.text_channels[1].send.assert_not_called() + + async def test_send_to_multiple_channels(self): + """Tests sending messages to two channels.""" + message = "Test basic message." + await self.cog.send_message(message, *self.text_channels, alert_target=True) + + self.text_channels[0].send.assert_awaited_once_with(message) + self.text_channels[1].send.assert_awaited_once_with(message) + + async def test_duration_replacement(self): + """Tests that the channel name was set correctly for one target channel.""" + message = "Current. The following should be replaced: {channel}." + await self.cog.send_message(message, *self.text_channels, alert_target=False) + + updated_message = message.format(channel=self.text_channels[0].mention) + self.text_channels[0].send.assert_awaited_once_with(updated_message) + self.text_channels[1].send.assert_not_called() + + async def test_name_replacement_multiple_channels(self): + """Tests that the channel name was set correctly for two channels.""" + message = "Current. The following should be replaced: {channel}." + await self.cog.send_message(message, *self.text_channels, alert_target=True) + + self.text_channels[0].send.assert_awaited_once_with(message.format(channel=self.text_channels[0].mention)) + self.text_channels[1].send.assert_awaited_once_with(message.format(channel="current channel")) + + async def test_silence_voice(self): + """Tests that the correct message was sent when a voice channel is muted without alerting.""" + message = "This should show up just here." + await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=False) + self.text_channels[0].send.assert_awaited_once_with(message) + self.text_channels[1].send.assert_not_called() + + async def test_silence_voice_alert(self): + """Tests that the correct message was sent when a voice channel is muted with alerts.""" + with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: + mock_voice_channels.get.return_value = self.text_channels[1].id + + message = "This should show up as {channel}." + await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=True) + + updated_message = message.format(channel=self.voice_channel.mention) + self.text_channels[0].send.assert_awaited_once_with(updated_message) + self.text_channels[1].send.assert_awaited_once_with(updated_message) + + mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) + + async def test_silence_voice_sibling_channel(self): + """Tests silencing a voice channel from the related text channel.""" + with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: + mock_voice_channels.get.return_value = self.text_channels[1].id + + message = "This should show up as {channel}." + await self.cog.send_message(message, self.text_channels[1], self.voice_channel, alert_target=True) + + updated_message = message.format(channel=self.voice_channel.mention) + self.text_channels[1].send.assert_awaited_once_with(updated_message) + + mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) + self.bot.get_channel.assert_called_once_with(self.text_channels[1].id) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index 4af84dde5..2a1c4e543 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -291,7 +291,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase): ("10", 10), ("5m", 5), ("5M", 5), - ("forever", None), + ("forever", -1), ) converter = HushDurationConverter() for minutes_string, expected_minutes in test_values: diff --git a/tests/helpers.py b/tests/helpers.py index eedd7a601..3978076ed 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient from bot.bot import Bot from tests._autospec import autospec # noqa: F401 other modules import it via this module - for logger in logging.Logger.manager.loggerDict.values(): # Set all loggers to CRITICAL by default to prevent screen clutter during testing @@ -320,7 +319,10 @@ channel_data = { } state = unittest.mock.MagicMock() guild = unittest.mock.MagicMock() -channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + +channel_data["type"] = "VoiceChannel" +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data) class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): @@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ - spec_set = channel_instance + spec_set = text_channel_instance + + def __init__(self, **kwargs) -> None: + default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"#{self.name}" + + +class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): + """ + A MagicMock subclass to mock VoiceChannel objects. + + Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For + more information, see the `MockGuild` docstring. + """ + spec_set = voice_channel_instance def __init__(self, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} |