diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/README.md | 31 | ||||
| -rw-r--r-- | tests/bot/exts/backend/test_error_handler.py | 88 | ||||
| -rw-r--r-- | tests/bot/exts/events/__init__.py | 0 | ||||
| -rw-r--r-- | tests/bot/exts/events/test_code_jams.py | 170 | ||||
| -rw-r--r-- | tests/bot/exts/info/test_help.py | 23 | ||||
| -rw-r--r-- | tests/bot/exts/info/test_information.py | 13 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_infractions.py | 6 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/infraction/test_utils.py | 10 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_modlog.py | 2 | ||||
| -rw-r--r-- | tests/bot/exts/moderation/test_silence.py | 600 | ||||
| -rw-r--r-- | tests/bot/exts/utils/test_jams.py | 171 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 26 | ||||
| -rw-r--r-- | tests/bot/test_converters.py | 27 | ||||
| -rw-r--r-- | tests/bot/utils/test_message_cache.py | 214 | ||||
| -rw-r--r-- | tests/bot/utils/test_time.py | 73 | ||||
| -rw-r--r-- | tests/helpers.py | 47 | 
16 files changed, 1141 insertions, 360 deletions
| diff --git a/tests/README.md b/tests/README.md index 0192f916e..b7fddfaa2 100644 --- a/tests/README.md +++ b/tests/README.md @@ -4,6 +4,14 @@ Our bot is one of the most important tools we have for running our community. As  _**Note:** This is a practical guide to getting started with writing tests for our bot, not a general introduction to writing unit tests in Python. If you're looking for a more general introduction, you can take a look at the [Additional resources](#additional-resources) section at the bottom of this page._ +### Table of contents: +- [Tools](#tools) +- [Running tests](#running-tests)   +- [Writing tests](#writing-tests) +- [Mocking](#mocking) +- [Some considerations](#some-considerations) +- [Additional resources](#additional-resources) +  ## Tools  We are using the following modules and packages for our unit tests: @@ -25,6 +33,29 @@ To ensure the results you obtain on your personal machine are comparable to thos  If you want a coverage report, make sure to run the tests with `poetry run task test` *first*. +## Running tests +There are multiple ways to run the tests, which one you use will be determined by your goal, and stage in development. + +When actively developing, you'll most likely be working on one portion of the codebase, and as a result, won't need to run the entire test suite. +To run just one file, and save time, you can use the following command: +```shell +poetry run task test-nocov <path/to/file.py> +``` + +For example: +```shell +poetry run task test-nocov tests/bot/exts/test_cogs.py +``` +will run the test suite in the `test_cogs` file. + +If you'd like to collect coverage as well, you can append `--cov` to the command above. + + +If you're done and are preparing to commit and push your code, it's a good idea to run the entire test suite as a sanity check: +```shell +poetry run task test +``` +  ## Writing tests  Since consistency is an important consideration for collaborative projects, we have written some guidelines on writing tests for the bot. In addition to these guidelines, it's a good idea to look at the existing code base for examples (e.g., [`test_converters.py`](/tests/bot/test_converters.py)). diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index eafcbae6c..ce59ee5fa 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -4,12 +4,12 @@ from unittest.mock import AsyncMock, MagicMock, call, patch  from discord.ext.commands import errors  from bot.api import ResponseCodeError -from bot.errors import InvalidInfractedUser, LockedResourceError +from bot.errors import InvalidInfractedUserError, LockedResourceError  from bot.exts.backend.error_handler import ErrorHandler, setup  from bot.exts.info.tags import Tags  from bot.exts.moderation.silence import Silence  from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockRole +from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel  class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -130,7 +130,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):                  "expect_mock_call": "send"              },              { -                "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))), +                "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUserError(self.ctx.author))),                  "expect_mock_call": "send"              }          ) @@ -226,8 +226,8 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):          self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError())          self.assertFalse(await self.cog.try_silence(self.ctx)) -    async def test_try_silence_silencing(self): -        """Should run silence command with correct arguments.""" +    async def test_try_silence_silence_duration(self): +        """Should run silence command with correct duration argument."""          self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)          test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh") @@ -238,21 +238,85 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):                  self.assertTrue(await self.cog.try_silence(self.ctx))                  self.ctx.invoke.assert_awaited_once_with(                      self.bot.get_command.return_value, -                    duration=min(case.count("h")*2, 15) +                    duration_or_channel=None, +                    duration=min(case.count("h")*2, 15), +                    kick=False                  ) +    async def test_try_silence_silence_arguments(self): +        """Should run silence with the correct channel, duration, and kick arguments.""" +        self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) + +        test_cases = ( +            (MockTextChannel(), None),  # None represents the case when no argument is passed +            (MockTextChannel(), False), +            (MockTextChannel(), True) +        ) + +        for channel, kick in test_cases: +            with self.subTest(kick=kick, channel=channel): +                self.ctx.reset_mock() +                self.ctx.invoked_with = "shh" + +                self.ctx.message.content = f"!shh {channel.name} {kick if kick is not None else ''}" +                self.ctx.guild.text_channels = [channel] + +                self.assertTrue(await self.cog.try_silence(self.ctx)) +                self.ctx.invoke.assert_awaited_once_with( +                    self.bot.get_command.return_value, +                    duration_or_channel=channel, +                    duration=4, +                    kick=(kick if kick is not None else False) +                ) + +    async def test_try_silence_silence_message(self): +        """If the words after the command could not be converted to a channel, None should be passed as channel.""" +        self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) +        self.ctx.invoked_with = "shh" +        self.ctx.message.content = "!shh not_a_channel true" + +        self.assertTrue(await self.cog.try_silence(self.ctx)) +        self.ctx.invoke.assert_awaited_once_with( +            self.bot.get_command.return_value, +            duration_or_channel=None, +            duration=4, +            kick=False +        ) +      async def test_try_silence_unsilence(self): -        """Should call unsilence command.""" +        """Should call unsilence command with correct duration and channel arguments."""          self.silence.silence.can_run = AsyncMock(return_value=True) -        test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh") +        test_cases = ( +            ("unshh", None), +            ("unshhhhh", None), +            ("unshhhhhhhhh", None), +            ("unshh", MockTextChannel()) +        ) -        for case in test_cases: -            with self.subTest(message=case): +        for invoke, channel in test_cases: +            with self.subTest(message=invoke, channel=channel):                  self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence)                  self.ctx.reset_mock() -                self.ctx.invoked_with = case + +                self.ctx.invoked_with = invoke +                self.ctx.message.content = f"!{invoke}" +                if channel is not None: +                    self.ctx.message.content += f" {channel.name}" +                    self.ctx.guild.text_channels = [channel] +                  self.assertTrue(await self.cog.try_silence(self.ctx)) -                self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence) +                self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=channel) + +    async def test_try_silence_unsilence_message(self): +        """If the words after the command could not be converted to a channel, None should be passed as channel.""" +        self.silence.silence.can_run = AsyncMock(return_value=True) +        self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) + +        self.ctx.invoked_with = "unshh" +        self.ctx.message.content = "!unshh not_a_channel" + +        self.assertTrue(await self.cog.try_silence(self.ctx)) +        self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence, channel=None)      async def test_try_silence_no_match(self):          """Should return `False` when message don't match.""" diff --git a/tests/bot/exts/events/__init__.py b/tests/bot/exts/events/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/events/__init__.py diff --git a/tests/bot/exts/events/test_code_jams.py b/tests/bot/exts/events/test_code_jams.py new file mode 100644 index 000000000..b9ee1e363 --- /dev/null +++ b/tests/bot/exts/events/test_code_jams.py @@ -0,0 +1,170 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch + +from discord import CategoryChannel +from discord.ext.commands import BadArgument + +from bot.constants import Roles +from bot.exts.events import code_jams +from bot.exts.events.code_jams import _channels, _cog +from tests.helpers import ( +    MockAttachment, MockBot, MockCategoryChannel, MockContext, +    MockGuild, MockMember, MockRole, MockTextChannel, autospec +) + +TEST_CSV = b"""\ +Team Name,Team Member Discord ID,Team Leader +Annoyed Alligators,12345,Y +Annoyed Alligators,54321,N +Oscillating Otters,12358,Y +Oscillating Otters,74832,N +Oscillating Otters,19903,N +Annoyed Alligators,11111,N +""" + + +def get_mock_category(channel_count: int, name: str) -> CategoryChannel: +    """Return a mocked code jam category.""" +    category = create_autospec(CategoryChannel, spec_set=True, instance=True) +    category.name = name +    category.channels = [MockTextChannel() for _ in range(channel_count)] + +    return category + + +class JamCodejamCreateTests(unittest.IsolatedAsyncioTestCase): +    """Tests for `codejam create` command.""" + +    def setUp(self): +        self.bot = MockBot() +        self.admin_role = MockRole(name="Admins", id=Roles.admins) +        self.command_user = MockMember([self.admin_role]) +        self.guild = MockGuild([self.admin_role]) +        self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) +        self.cog = _cog.CodeJams(self.bot) + +    async def test_message_without_attachments(self): +        """If no link or attachments are provided, commands.BadArgument should be raised.""" +        self.ctx.message.attachments = [] + +        with self.assertRaises(BadArgument): +            await self.cog.create(self.cog, self.ctx, None) + +    @patch.object(_channels, "create_team_channel") +    @patch.object(_channels, "create_team_leader_channel") +    async def test_result_sending(self, create_leader_channel, create_team_channel): +        """Should call `ctx.send` when everything goes right.""" +        self.ctx.message.attachments = [MockAttachment()] +        self.ctx.message.attachments[0].read = AsyncMock() +        self.ctx.message.attachments[0].read.return_value = TEST_CSV + +        team_leaders = MockRole() + +        self.guild.get_member.return_value = MockMember() + +        self.ctx.guild.create_role = AsyncMock() +        self.ctx.guild.create_role.return_value = team_leaders +        self.cog.add_roles = AsyncMock() + +        await self.cog.create(self.cog, self.ctx, None) + +        create_team_channel.assert_awaited() +        create_leader_channel.assert_awaited_once_with( +            self.ctx.guild, team_leaders +        ) +        self.ctx.send.assert_awaited_once() + +    async def test_link_returning_non_200_status(self): +        """When the URL passed returns a non 200 status, it should send a message informing them.""" +        self.bot.http_session.get.return_value = mock = MagicMock() +        mock.status = 404 +        await self.cog.create(self.cog, self.ctx, "https://not-a-real-link.com") + +        self.ctx.send.assert_awaited_once() + +    @patch.object(_channels, "_send_status_update") +    async def test_category_doesnt_exist(self, update): +        """Should create a new code jam category.""" +        subtests = ( +            [], +            [get_mock_category(_channels.MAX_CHANNELS, _channels.CATEGORY_NAME)], +            [get_mock_category(_channels.MAX_CHANNELS - 2, "other")], +        ) + +        for categories in subtests: +            update.reset_mock() +            self.guild.reset_mock() +            self.guild.categories = categories + +            with self.subTest(categories=categories): +                actual_category = await _channels._get_category(self.guild) + +                update.assert_called_once() +                self.guild.create_category_channel.assert_awaited_once() +                category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] + +                self.assertFalse(category_overwrites[self.guild.default_role].read_messages) +                self.assertTrue(category_overwrites[self.guild.me].read_messages) +                self.assertEqual(self.guild.create_category_channel.return_value, actual_category) + +    async def test_category_channel_exist(self): +        """Should not try to create category channel.""" +        expected_category = get_mock_category(_channels.MAX_CHANNELS - 2, _channels.CATEGORY_NAME) +        self.guild.categories = [ +            get_mock_category(_channels.MAX_CHANNELS - 2, "other"), +            expected_category, +            get_mock_category(0, _channels.CATEGORY_NAME), +        ] + +        actual_category = await _channels._get_category(self.guild) +        self.assertEqual(expected_category, actual_category) + +    async def test_channel_overwrites(self): +        """Should have correct permission overwrites for users and roles.""" +        leader = (MockMember(), True) +        members = [leader] + [(MockMember(), False) for _ in range(4)] +        overwrites = _channels._get_overwrites(members, self.guild) + +        for member, _ in members: +            self.assertTrue(overwrites[member].read_messages) + +    @patch.object(_channels, "_get_overwrites") +    @patch.object(_channels, "_get_category") +    @autospec(_channels, "_add_team_leader_roles", pass_mocks=False) +    async def test_team_channels_creation(self, get_category, get_overwrites): +        """Should create a text channel for a team.""" +        team_leaders = MockRole() +        members = [(MockMember(), True)] + [(MockMember(), False) for _ in range(5)] +        category = MockCategoryChannel() +        category.create_text_channel = AsyncMock() + +        get_category.return_value = category +        await _channels.create_team_channel(self.guild, "my-team", members, team_leaders) + +        category.create_text_channel.assert_awaited_once_with( +            "my-team", +            overwrites=get_overwrites.return_value +        ) + +    async def test_jam_roles_adding(self): +        """Should add team leader role to leader and jam role to every team member.""" +        leader_role = MockRole(name="Team Leader") + +        leader = MockMember() +        members = [(leader, True)] + [(MockMember(), False) for _ in range(4)] +        await _channels._add_team_leader_roles(members, leader_role) + +        leader.add_roles.assert_awaited_once_with(leader_role) +        for member, is_leader in members: +            if not is_leader: +                member.add_roles.assert_not_awaited() + + +class CodeJamSetup(unittest.TestCase): +    """Test for `setup` function of `CodeJam` cog.""" + +    def test_setup(self): +        """Should call `bot.add_cog`.""" +        bot = MockBot() +        code_jams.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/info/test_help.py b/tests/bot/exts/info/test_help.py new file mode 100644 index 000000000..604c69671 --- /dev/null +++ b/tests/bot/exts/info/test_help.py @@ -0,0 +1,23 @@ +import unittest + +import rapidfuzz + +from bot.exts.info import help +from tests.helpers import MockBot, MockContext, autospec + + +class HelpCogTests(unittest.IsolatedAsyncioTestCase): +    def setUp(self) -> None: +        """Attach an instance of the cog to the class for tests.""" +        self.bot = MockBot() +        self.cog = help.Help(self.bot) +        self.ctx = MockContext(bot=self.bot) +        self.bot.help_command.context = self.ctx + +    @autospec(help.CustomHelpCommand, "get_all_help_choices", return_value={"help"}, pass_mocks=False) +    async def test_help_fuzzy_matching(self): +        """Test fuzzy matching of commands when called from help.""" +        result = await self.bot.help_command.command_not_found("holp") + +        match = {"help": rapidfuzz.fuzz.ratio("help", "holp")} +        self.assertEqual(match, result.possible_matches) diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 770660fe3..d8250befb 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -262,7 +262,6 @@ class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):          await self._method_subtests(self.cog.user_nomination_counts, test_values, header) [email protected]("bot.exts.info.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago"))  @unittest.mock.patch("bot.exts.info.information.constants.MODERATION_CHANNELS", new=[50])  class UserEmbedTests(unittest.IsolatedAsyncioTestCase):      """Tests for the creation of the `!user` embed.""" @@ -347,7 +346,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(              textwrap.dedent(f""" -                Created: {"1 year ago"} +                Created: {"<t:1:R>"}                  Profile: {user.mention}                  ID: {user.id}              """).strip(), @@ -356,7 +355,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(              textwrap.dedent(f""" -                Joined: {"1 year ago"} +                Joined: {"<t:1:R>"}                  Verified: {"True"}                  Roles: &Moderators              """).strip(), @@ -379,7 +378,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(              textwrap.dedent(f""" -                Created: {"1 year ago"} +                Created: {"<t:1:R>"}                  Profile: {user.mention}                  ID: {user.id}              """).strip(), @@ -388,7 +387,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(              textwrap.dedent(f""" -                Joined: {"1 year ago"} +                Joined: {"<t:1:R>"}                  Roles: &Moderators              """).strip(),              embed.fields[1].value @@ -508,7 +507,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase):      @unittest.mock.patch("bot.exts.info.information.Information.create_user_embed")      async def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):          """Staff members should be able to bypass the bot-commands channel restriction.""" -        constants.STAFF_ROLES = [self.moderator_role.id] +        constants.STAFF_PARTNERS_COMMUNITY_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))          await self.cog.user_info(self.cog, ctx) @@ -520,7 +519,7 @@ class UserCommandTests(unittest.IsolatedAsyncioTestCase):      async def test_moderators_can_target_another_member(self, create_embed, constants):          """A moderator should be able to use `!user` targeting another user."""          constants.MODERATION_ROLES = [self.moderator_role.id] -        constants.STAFF_ROLES = [self.moderator_role.id] +        constants.STAFF_PARTNERS_COMMUNITY_ROLES = [self.moderator_role.id]          ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=50))          await self.cog.user_info(self.cog, ctx, self.target) diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index b9d527770..f844a9181 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -195,7 +195,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):      async def test_voice_unban_user_not_found(self):          """Should include info to return dict when user was not found from guild."""          self.guild.get_member.return_value = None -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {"Info": "User was not found in the guild."})      @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") @@ -206,7 +206,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):          notify_pardon_mock.return_value = True          format_user_mock.return_value = "my-user" -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {              "Member": "my-user",              "DM": "Sent" @@ -221,7 +221,7 @@ class VoiceBanTests(unittest.IsolatedAsyncioTestCase):          notify_pardon_mock.return_value = False          format_user_mock.return_value = "my-user" -        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild)          self.assertEqual(result, {              "Member": "my-user",              "DM": "**Failed**" diff --git a/tests/bot/exts/moderation/infraction/test_utils.py b/tests/bot/exts/moderation/infraction/test_utils.py index 50a717bb5..eb256f1fd 100644 --- a/tests/bot/exts/moderation/infraction/test_utils.py +++ b/tests/bot/exts/moderation/infraction/test_utils.py @@ -94,8 +94,8 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):          test_case = namedtuple("test_case", ["get_return_value", "expected_output", "infraction_nr", "send_msg"])          test_cases = [              test_case([], None, None, True), -            test_case([{"id": 123987}], {"id": 123987}, "123987", False), -            test_case([{"id": 123987}], {"id": 123987}, "123987", True) +            test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", False), +            test_case([{"id": 123987, "type": "ban"}], {"id": 123987, "type": "ban"}, "123987", True)          ]          for case in test_cases: @@ -137,7 +137,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):                      title=utils.INFRACTION_TITLE,                      description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(                          type="Ban", -                        expires="2020-02-26 09:20 (23 hours and 59 minutes) UTC", +                        expires="2020-02-26 09:20 (23 hours and 59 minutes)",                          reason="No reason provided."                      ),                      colour=Colours.soft_red, @@ -193,7 +193,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):                      title=utils.INFRACTION_TITLE,                      description=utils.INFRACTION_DESCRIPTION_TEMPLATE.format(                          type="Mute", -                        expires="2020-02-26 09:20 (23 hours and 59 minutes) UTC", +                        expires="2020-02-26 09:20 (23 hours and 59 minutes)",                          reason="Test"                      ),                      colour=Colours.soft_red, @@ -213,7 +213,7 @@ class ModerationUtilsTests(unittest.IsolatedAsyncioTestCase):                          type="Mute",                          expires="N/A",                          reason="foo bar" * 4000 -                    )[:2045] + "...", +                    )[:4093] + "...",                      colour=Colours.soft_red,                      url=utils.RULES_URL                  ).set_author( diff --git a/tests/bot/exts/moderation/test_modlog.py b/tests/bot/exts/moderation/test_modlog.py index f8f142484..79e04837d 100644 --- a/tests/bot/exts/moderation/test_modlog.py +++ b/tests/bot/exts/moderation/test_modlog.py @@ -25,5 +25,5 @@ class ModLogTests(unittest.IsolatedAsyncioTestCase):          )          embed = self.channel.send.call_args[1]["embed"]          self.assertEqual( -            embed.description, ("foo bar" * 3000)[:2045] + "..." +            embed.description, ("foo bar" * 3000)[:4093] + "..."          ) diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index fa5fc9e81..59a5893ef 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,15 +1,26 @@  import asyncio +import itertools  import unittest  from datetime import datetime, timezone +from typing import List, Tuple  from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock  from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Guild, Roles +from bot.constants import Channels, Guild, MODERATION_ROLES, Roles  from bot.exts.moderation import silence -from tests.helpers import MockBot, MockContext, MockTextChannel, autospec +from tests.helpers import ( +    MockBot, +    MockContext, +    MockGuild, +    MockMember, +    MockRole, +    MockTextChannel, +    MockVoiceChannel, +    autospec +)  redis_session = None  redis_loop = asyncio.get_event_loop() @@ -149,7 +160,7 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):          self.assertTrue(self.cog._init_task.cancelled())      @autospec("discord.ext.commands", "has_any_role") -    @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) +    @mock.patch.object(silence.constants, "MODERATION_ROLES", new=(1, 2, 3))      async def test_cog_check(self, role_check):          """Role check was called with `MODERATION_ROLES`"""          ctx = MockContext() @@ -159,6 +170,170 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):          role_check.assert_called_once_with(*(1, 2, 3))          role_check.return_value.predicate.assert_awaited_once_with(ctx) +    async def test_force_voice_sync(self): +        """Tests the _force_voice_sync helper function.""" +        await self.cog._async_init() + +        # Create a regular member, and one member for each of the moderation roles +        moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] +        members = [MockMember(), *moderation_members] + +        channel = MockVoiceChannel(members=members) + +        await self.cog._force_voice_sync(channel) +        for member in members: +            if member in moderation_members: +                member.move_to.assert_not_called() +            else: +                self.assertEqual(member.move_to.call_count, 2) +                calls = member.move_to.call_args_list + +                # Tests that the member was moved to the afk channel, and back. +                self.assertEqual((channel.guild.afk_channel,), calls[0].args) +                self.assertEqual((channel,), calls[1].args) + +    async def test_force_voice_sync_no_channel(self): +        """Test to ensure _force_voice_sync can create its own voice channel if one is not available.""" +        await self.cog._async_init() + +        channel = MockVoiceChannel(guild=MockGuild(afk_channel=None)) +        new_channel = MockVoiceChannel(delete=AsyncMock()) +        channel.guild.create_voice_channel.return_value = new_channel + +        await self.cog._force_voice_sync(channel) + +        # Check channel creation +        overwrites = { +            channel.guild.default_role: PermissionOverwrite(speak=False, connect=False, view_channel=False) +        } +        channel.guild.create_voice_channel.assert_awaited_once_with("mute-temp", overwrites=overwrites) + +        # Check bot deleted channel +        new_channel.delete.assert_awaited_once() + +    async def test_voice_kick(self): +        """Test to ensure kick function can remove all members from a voice channel.""" +        await self.cog._async_init() + +        # Create a regular member, and one member for each of the moderation roles +        moderation_members = [MockMember(roles=[MockRole(id=role)]) for role in MODERATION_ROLES] +        members = [MockMember(), *moderation_members] + +        channel = MockVoiceChannel(members=members) +        await self.cog._kick_voice_members(channel) + +        for member in members: +            if member in moderation_members: +                member.move_to.assert_not_called() +            else: +                self.assertEqual((None,), member.move_to.call_args_list[0].args) + +    @staticmethod +    def create_erroneous_members() -> Tuple[List[MockMember], List[MockMember]]: +        """ +        Helper method to generate a list of members that error out on move_to call. + +        Returns the list of erroneous members, +        as well as a list of regular and erroneous members combined, in that order. +        """ +        erroneous_member = MockMember(move_to=AsyncMock(side_effect=Exception())) +        members = [MockMember(), erroneous_member] + +        return erroneous_member, members + +    async def test_kick_move_to_error(self): +        """Test to ensure move_to gets called on all members during kick, even if some fail.""" +        await self.cog._async_init() +        _, members = self.create_erroneous_members() + +        await self.cog._kick_voice_members(MockVoiceChannel(members=members)) +        for member in members: +            member.move_to.assert_awaited_once() + +    async def test_sync_move_to_error(self): +        """Test to ensure move_to gets called on all members during sync, even if some fail.""" +        await self.cog._async_init() +        failing_member, members = self.create_erroneous_members() + +        await self.cog._force_voice_sync(MockVoiceChannel(members=members)) +        for member in members: +            self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2) + + +class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the silence argument parser utility function.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +    @autospec(silence.Silence, "send_message", pass_mocks=False) +    @autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False) +    @autospec(silence.Silence, "parse_silence_args") +    async def test_command(self, parser_mock): +        """Test that the command passes in the correct arguments for different calls.""" +        test_cases = ( +            (), +            (15, ), +            (MockTextChannel(),), +            (MockTextChannel(), 15), +        ) + +        ctx = MockContext() +        parser_mock.return_value = (ctx.channel, 10) + +        for case in test_cases: +            with self.subTest("Test command converters", args=case): +                await self.cog.silence.callback(self.cog, ctx, *case) + +                try: +                    first_arg = case[0] +                except IndexError: +                    # Default value when the first argument is not passed +                    first_arg = None + +                try: +                    second_arg = case[1] +                except IndexError: +                    # Default value when the second argument is not passed +                    second_arg = 10 + +                parser_mock.assert_called_with(ctx, first_arg, second_arg) + +    async def test_no_arguments(self): +        """Test the parser when no arguments are passed to the command.""" +        ctx = MockContext() +        channel, duration = self.cog.parse_silence_args(ctx, None, 10) + +        self.assertEqual(ctx.channel, channel) +        self.assertEqual(10, duration) + +    async def test_channel_only(self): +        """Test the parser when just the channel argument is passed.""" +        expected_channel = MockTextChannel() +        actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 10) + +        self.assertEqual(expected_channel, actual_channel) +        self.assertEqual(10, duration) + +    async def test_duration_only(self): +        """Test the parser when just the duration argument is passed.""" +        ctx = MockContext() +        channel, duration = self.cog.parse_silence_args(ctx, 15, 10) + +        self.assertEqual(ctx.channel, channel) +        self.assertEqual(15, duration) + +    async def test_all_args(self): +        """Test the parser when both channel and duration are passed.""" +        expected_channel = MockTextChannel() +        actual_channel, duration = self.cog.parse_silence_args(MockContext(), expected_channel, 15) + +        self.assertEqual(expected_channel, actual_channel) +        self.assertEqual(15, duration) +  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class RescheduleTests(unittest.IsolatedAsyncioTestCase): @@ -235,6 +410,16 @@ class RescheduleTests(unittest.IsolatedAsyncioTestCase):          self.cog.notifier.add_channel.assert_not_called() +def voice_sync_helper(function): +    """Helper wrapper to test the sync and kick functions for voice channels.""" +    @autospec(silence.Silence, "_force_voice_sync", "_kick_voice_members", "_set_silence_overwrites") +    async def inner(self, sync, kick, overwrites): +        overwrites.return_value = True +        await function(self, MockContext(), sync, kick) + +    return inner + +  @autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)  class SilenceTests(unittest.IsolatedAsyncioTestCase):      """Tests for the silence command and its related helper methods.""" @@ -242,7 +427,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      @autospec(silence.Silence, "_reschedule", pass_mocks=False)      @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)      def setUp(self) -> None: -        self.bot = MockBot() +        self.bot = MockBot(get_channel=lambda _: MockTextChannel())          self.cog = silence.Silence(self.bot)          self.cog._init_task = asyncio.Future()          self.cog._init_task.set_result(None) @@ -252,56 +437,127 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          asyncio.run(self.cog._async_init())  # Populate instance attributes. -        self.channel = MockTextChannel() -        self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) -        self.channel.overwrites_for.return_value = self.overwrite +        self.text_channel = MockTextChannel() +        self.text_overwrite = PermissionOverwrite(send_messages=True, add_reactions=False) +        self.text_channel.overwrites_for.return_value = self.text_overwrite + +        self.voice_channel = MockVoiceChannel() +        self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) +        self.voice_channel.overwrites_for.return_value = self.voice_overwrite      async def test_sent_correct_message(self): -        """Appropriate failure/success message was sent by the command.""" +        """Appropriate failure/success message was sent by the command to the correct channel.""" +        # The following test tuples are made up of: +        # duration, expected message, and the success of the _set_silence_overwrites function          test_cases = (              (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,),              (None, silence.MSG_SILENCE_PERMANENT, True,),              (5, silence.MSG_SILENCE_FAIL, False,),          ) -        for duration, message, was_silenced in test_cases: -            ctx = MockContext() + +        targets = (MockTextChannel(), MockVoiceChannel(), None) + +        for (duration, message, was_silenced), target in itertools.product(test_cases, targets):              with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): -                with self.subTest(was_silenced=was_silenced, message=message, duration=duration): -                    await self.cog.silence.callback(self.cog, ctx, duration) -                    ctx.send.assert_called_once_with(message) +                with self.subTest(was_silenced=was_silenced, target=target, message=message): +                    with mock.patch.object(self.cog, "send_message") as send_message: +                        ctx = MockContext() +                        await self.cog.silence.callback(self.cog, ctx, target, duration) +                        send_message.assert_called_once_with( +                            message, +                            ctx.channel, +                            target or ctx.channel, +                            alert_target=was_silenced +                        ) + +    @voice_sync_helper +    async def test_sync_called(self, ctx, sync, kick): +        """Tests if silence command calls sync on a voice channel.""" +        channel = MockVoiceChannel() +        await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=False) + +        sync.assert_awaited_once_with(self.cog, channel) +        kick.assert_not_called() + +    @voice_sync_helper +    async def test_kick_called(self, ctx, sync, kick): +        """Tests if silence command calls kick on a voice channel.""" +        channel = MockVoiceChannel() +        await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=True) + +        kick.assert_awaited_once_with(channel) +        sync.assert_not_called() + +    @voice_sync_helper +    async def test_sync_not_called(self, ctx, sync, kick): +        """Tests that silence command does not call sync on a text channel.""" +        channel = MockTextChannel() +        await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=False) + +        sync.assert_not_called() +        kick.assert_not_called() + +    @voice_sync_helper +    async def test_kick_not_called(self, ctx, sync, kick): +        """Tests that silence command does not call kick on a text channel.""" +        channel = MockTextChannel() +        await self.cog.silence.callback(self.cog, ctx, channel, 10, kick=True) + +        sync.assert_not_called() +        kick.assert_not_called()      async def test_skipped_already_silenced(self):          """Permissions were not set and `False` was returned for an already silenced channel."""          subtests = ( -            (False, PermissionOverwrite(send_messages=False, add_reactions=False)), -            (True, PermissionOverwrite(send_messages=True, add_reactions=True)), -            (True, PermissionOverwrite(send_messages=False, add_reactions=False)), +            (False, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)), +            (True, MockTextChannel(), PermissionOverwrite(send_messages=True, add_reactions=True)), +            (True, MockTextChannel(), PermissionOverwrite(send_messages=False, add_reactions=False)), +            (False, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)), +            (True, MockVoiceChannel(), PermissionOverwrite(connect=True, speak=True)), +            (True, MockVoiceChannel(), PermissionOverwrite(connect=False, speak=False)),          ) -        for contains, overwrite in subtests: -            with self.subTest(contains=contains, overwrite=overwrite): +        for contains, channel, overwrite in subtests: +            with self.subTest(contains=contains, is_text=isinstance(channel, MockTextChannel), overwrite=overwrite):                  self.cog.scheduler.__contains__.return_value = contains -                channel = MockTextChannel()                  channel.overwrites_for.return_value = overwrite                  self.assertFalse(await self.cog._set_silence_overwrites(channel))                  channel.set_permissions.assert_not_called() -    async def test_silenced_channel(self): +    async def test_silenced_text_channel(self):          """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" -        self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) -        self.assertFalse(self.overwrite.send_messages) -        self.assertFalse(self.overwrite.add_reactions) -        self.channel.set_permissions.assert_awaited_once_with( +        self.assertTrue(await self.cog._set_silence_overwrites(self.text_channel)) +        self.assertFalse(self.text_overwrite.send_messages) +        self.assertFalse(self.text_overwrite.add_reactions) +        self.text_channel.set_permissions.assert_awaited_once_with(              self.cog._everyone_role, -            overwrite=self.overwrite +            overwrite=self.text_overwrite          ) -    async def test_preserved_other_overwrites(self): -        """Channel's other unrelated overwrites were not changed.""" -        prev_overwrite_dict = dict(self.overwrite) -        await self.cog._set_silence_overwrites(self.channel) -        new_overwrite_dict = dict(self.overwrite) +    async def test_silenced_voice_channel_speak(self): +        """Channel had `speak` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel)) +        self.assertFalse(self.voice_overwrite.speak) +        self.voice_channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_voice_role, +            overwrite=self.voice_overwrite +        ) + +    async def test_silenced_voice_channel_full(self): +        """Channel had `speak` and `connect` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.voice_channel, kick=True)) +        self.assertFalse(self.voice_overwrite.speak or self.voice_overwrite.connect) +        self.voice_channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_voice_role, +            overwrite=self.voice_overwrite +        ) + +    async def test_preserved_other_overwrites_text(self): +        """Channel's other unrelated overwrites were not changed for a text channel mute.""" +        prev_overwrite_dict = dict(self.text_overwrite) +        await self.cog._set_silence_overwrites(self.text_channel) +        new_overwrite_dict = dict(self.text_overwrite)          # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.          del prev_overwrite_dict['send_messages'] @@ -311,6 +567,20 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) +    async def test_preserved_other_overwrites_voice(self): +        """Channel's other unrelated overwrites were not changed for a voice channel mute.""" +        prev_overwrite_dict = dict(self.voice_overwrite) +        await self.cog._set_silence_overwrites(self.voice_channel) +        new_overwrite_dict = dict(self.voice_overwrite) + +        # Remove 'connect' & 'speak' keys because they were changed by the method. +        del prev_overwrite_dict['connect'] +        del prev_overwrite_dict['speak'] +        del new_overwrite_dict['connect'] +        del new_overwrite_dict['speak'] + +        self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) +      async def test_temp_not_added_to_notifier(self):          """Channel was not added to notifier if a duration was set for the silence."""          with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): @@ -320,7 +590,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      async def test_indefinite_added_to_notifier(self):          """Channel was added to notifier if a duration was not set for the silence."""          with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): -            await self.cog.silence.callback(self.cog, MockContext(), None) +            await self.cog.silence.callback(self.cog, MockContext(), None, None)              self.cog.notifier.add_channel.assert_called_once()      async def test_silenced_not_added_to_notifier(self): @@ -332,8 +602,8 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      async def test_cached_previous_overwrites(self):          """Channel's previous overwrites were cached."""          overwrite_json = '{"send_messages": true, "add_reactions": false}' -        await self.cog._set_silence_overwrites(self.channel) -        self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) +        await self.cog._set_silence_overwrites(self.text_channel) +        self.cog.previous_overwrites.set.assert_awaited_once_with(self.text_channel.id, overwrite_json)      @autospec(silence, "datetime")      async def test_cached_unsilence_time(self, datetime_mock): @@ -343,7 +613,7 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):          timestamp = now_timestamp + duration * 60          datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) -        ctx = MockContext(channel=self.channel) +        ctx = MockContext(channel=self.text_channel)          await self.cog.silence.callback(self.cog, ctx, duration)          self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) @@ -351,26 +621,33 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):      async def test_cached_indefinite_time(self):          """A value of -1 was cached for a permanent silence.""" -        ctx = MockContext(channel=self.channel) -        await self.cog.silence.callback(self.cog, ctx, None) +        ctx = MockContext(channel=self.text_channel) +        await self.cog.silence.callback(self.cog, ctx, None, None)          self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)      async def test_scheduled_task(self):          """An unsilence task was scheduled.""" -        ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) +        ctx = MockContext(channel=self.text_channel, invoke=mock.MagicMock())          await self.cog.silence.callback(self.cog, ctx, 5)          args = (300, ctx.channel.id, ctx.invoke.return_value)          self.cog.scheduler.schedule_later.assert_called_once_with(*args) -        ctx.invoke.assert_called_once_with(self.cog.unsilence) +        ctx.invoke.assert_called_once_with(self.cog.unsilence, channel=ctx.channel)      async def test_permanent_not_scheduled(self):          """A task was not scheduled for a permanent silence.""" -        ctx = MockContext(channel=self.channel) -        await self.cog.silence.callback(self.cog, ctx, None) +        ctx = MockContext(channel=self.text_channel) +        await self.cog.silence.callback(self.cog, ctx, None, None)          self.cog.scheduler.schedule_later.assert_not_called() +    async def test_indefinite_silence(self): +        """Test silencing a channel forever.""" +        with mock.patch.object(self.cog, "_schedule_unsilence") as unsilence: +            ctx = MockContext(channel=self.text_channel) +            await self.cog.silence.callback(self.cog, ctx, -1) +            unsilence.assert_awaited_once_with(ctx, ctx.channel, None) +  @autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)  class UnsilenceTests(unittest.IsolatedAsyncioTestCase): @@ -391,9 +668,13 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):          self.cog.scheduler.__contains__.return_value = True          overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' -        self.channel = MockTextChannel() -        self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) -        self.channel.overwrites_for.return_value = self.overwrite +        self.text_channel = MockTextChannel() +        self.text_overwrite = PermissionOverwrite(send_messages=False, add_reactions=False) +        self.text_channel.overwrites_for.return_value = self.text_overwrite + +        self.voice_channel = MockVoiceChannel() +        self.voice_overwrite = PermissionOverwrite(connect=True, speak=True) +        self.voice_channel.overwrites_for.return_value = self.voice_overwrite      async def test_sent_correct_message(self):          """Appropriate failure/success message was sent by the command.""" @@ -401,88 +682,128 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):          test_cases = (              (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite),              (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), -            (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, self.text_overwrite),              (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)),              (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)),          ) -        for was_unsilenced, message, overwrite in test_cases: + +        targets = (None, MockTextChannel()) + +        for (was_unsilenced, message, overwrite), target in itertools.product(test_cases, targets):              ctx = MockContext() -            with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): -                with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): -                    ctx.channel.overwrites_for.return_value = overwrite -                    await self.cog.unsilence.callback(self.cog, ctx) -                    ctx.channel.send.assert_called_once_with(message) +            ctx.channel.overwrites_for.return_value = overwrite +            if target: +                target.overwrites_for.return_value = overwrite + +            with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): +                with mock.patch.object(self.cog, "send_message") as send_message: +                    with self.subTest(was_unsilenced=was_unsilenced, overwrite=overwrite, target=target): +                        await self.cog.unsilence.callback(self.cog, ctx, channel=target) + +                        call_args = (message, ctx.channel, target or ctx.channel) +                        send_message.assert_awaited_once_with(*call_args, alert_target=was_unsilenced)      async def test_skipped_already_unsilenced(self):          """Permissions were not set and `False` was returned for an already unsilenced channel."""          self.cog.scheduler.__contains__.return_value = False          self.cog.previous_overwrites.get.return_value = None -        channel = MockTextChannel() -        self.assertFalse(await self.cog._unsilence(channel)) -        channel.set_permissions.assert_not_called() +        for channel in (MockVoiceChannel(), MockTextChannel()): +            with self.subTest(channel=channel): +                self.assertFalse(await self.cog._unsilence(channel)) +                channel.set_permissions.assert_not_called() -    async def test_restored_overwrites(self): -        """Channel's `send_message` and `add_reactions` overwrites were restored.""" -        await self.cog._unsilence(self.channel) -        self.channel.set_permissions.assert_awaited_once_with( +    async def test_restored_overwrites_text(self): +        """Text channel's `send_message` and `add_reactions` overwrites were restored.""" +        await self.cog._unsilence(self.text_channel) +        self.text_channel.set_permissions.assert_awaited_once_with(              self.cog._everyone_role, -            overwrite=self.overwrite, +            overwrite=self.text_overwrite, +        ) + +        # Recall that these values are determined by the fixture. +        self.assertTrue(self.text_overwrite.send_messages) +        self.assertFalse(self.text_overwrite.add_reactions) + +    async def test_restored_overwrites_voice(self): +        """Voice channel's `connect` and `speak` overwrites were restored.""" +        await self.cog._unsilence(self.voice_channel) +        self.voice_channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_voice_role, +            overwrite=self.voice_overwrite,          )          # Recall that these values are determined by the fixture. -        self.assertTrue(self.overwrite.send_messages) -        self.assertFalse(self.overwrite.add_reactions) +        self.assertTrue(self.voice_overwrite.connect) +        self.assertTrue(self.voice_overwrite.speak) -    async def test_cache_miss_used_default_overwrites(self): -        """Both overwrites were set to None due previous values not being found in the cache.""" +    async def test_cache_miss_used_default_overwrites_text(self): +        """Text overwrites were set to None due previous values not being found in the cache."""          self.cog.previous_overwrites.get.return_value = None -        await self.cog._unsilence(self.channel) -        self.channel.set_permissions.assert_awaited_once_with( +        await self.cog._unsilence(self.text_channel) +        self.text_channel.set_permissions.assert_awaited_once_with(              self.cog._everyone_role, -            overwrite=self.overwrite, +            overwrite=self.text_overwrite, +        ) + +        self.assertIsNone(self.text_overwrite.send_messages) +        self.assertIsNone(self.text_overwrite.add_reactions) + +    async def test_cache_miss_used_default_overwrites_voice(self): +        """Voice overwrites were set to None due previous values not being found in the cache.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.voice_channel) +        self.voice_channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_voice_role, +            overwrite=self.voice_overwrite,          ) -        self.assertIsNone(self.overwrite.send_messages) -        self.assertIsNone(self.overwrite.add_reactions) +        self.assertIsNone(self.voice_overwrite.connect) +        self.assertIsNone(self.voice_overwrite.speak) -    async def test_cache_miss_sent_mod_alert(self): -        """A message was sent to the mod alerts channel.""" +    async def test_cache_miss_sent_mod_alert_text(self): +        """A message was sent to the mod alerts channel upon muting a text channel."""          self.cog.previous_overwrites.get.return_value = None +        await self.cog._unsilence(self.text_channel) +        self.cog._mod_alerts_channel.send.assert_awaited_once() -        await self.cog._unsilence(self.channel) +    async def test_cache_miss_sent_mod_alert_voice(self): +        """A message was sent to the mod alerts channel upon muting a voice channel.""" +        self.cog.previous_overwrites.get.return_value = None +        await self.cog._unsilence(MockVoiceChannel())          self.cog._mod_alerts_channel.send.assert_awaited_once()      async def test_removed_notifier(self):          """Channel was removed from `notifier`.""" -        await self.cog._unsilence(self.channel) -        self.cog.notifier.remove_channel.assert_called_once_with(self.channel) +        await self.cog._unsilence(self.text_channel) +        self.cog.notifier.remove_channel.assert_called_once_with(self.text_channel)      async def test_deleted_cached_overwrite(self):          """Channel was deleted from the overwrites cache.""" -        await self.cog._unsilence(self.channel) -        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) +        await self.cog._unsilence(self.text_channel) +        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.text_channel.id)      async def test_deleted_cached_time(self):          """Channel was deleted from the timestamp cache.""" -        await self.cog._unsilence(self.channel) -        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) +        await self.cog._unsilence(self.text_channel) +        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.text_channel.id)      async def test_cancelled_task(self):          """The scheduled unsilence task should be cancelled.""" -        await self.cog._unsilence(self.channel) -        self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) +        await self.cog._unsilence(self.text_channel) +        self.cog.scheduler.cancel.assert_called_once_with(self.text_channel.id) -    async def test_preserved_other_overwrites(self): -        """Channel's other unrelated overwrites were not changed, including cache misses.""" +    async def test_preserved_other_overwrites_text(self): +        """Text channel's other unrelated overwrites were not changed, including cache misses."""          for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):              with self.subTest(overwrite_json=overwrite_json):                  self.cog.previous_overwrites.get.return_value = overwrite_json -                prev_overwrite_dict = dict(self.overwrite) -                await self.cog._unsilence(self.channel) -                new_overwrite_dict = dict(self.overwrite) +                prev_overwrite_dict = dict(self.text_overwrite) +                await self.cog._unsilence(self.text_channel) +                new_overwrite_dict = dict(self.text_overwrite)                  # Remove these keys because they were modified by the unsilence.                  del prev_overwrite_dict['send_messages'] @@ -491,3 +812,114 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):                  del new_overwrite_dict['add_reactions']                  self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_preserved_other_overwrites_voice(self): +        """Voice channel's other unrelated overwrites were not changed, including cache misses.""" +        for overwrite_json in ('{"connect": true, "speak": true}', None): +            with self.subTest(overwrite_json=overwrite_json): +                self.cog.previous_overwrites.get.return_value = overwrite_json + +                prev_overwrite_dict = dict(self.voice_overwrite) +                await self.cog._unsilence(self.voice_channel) +                new_overwrite_dict = dict(self.voice_overwrite) + +                # Remove these keys because they were modified by the unsilence. +                del prev_overwrite_dict['connect'] +                del prev_overwrite_dict['speak'] +                del new_overwrite_dict['connect'] +                del new_overwrite_dict['speak'] + +                self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_unsilence_role(self): +        """Tests unsilence_wrapper applies permission to the correct role.""" +        test_cases = ( +            (MockTextChannel(), self.cog.bot.get_guild(Guild.id).default_role), +            (MockVoiceChannel(), self.cog.bot.get_guild(Guild.id).get_role(Roles.voice_verified)) +        ) + +        for channel, role in test_cases: +            with self.subTest(channel=channel, role=role): +                await self.cog._unsilence_wrapper(channel, MockContext()) +                channel.overwrites_for.assert_called_with(role) + + +class SendMessageTests(unittest.IsolatedAsyncioTestCase): +    """Unittests for the send message helper function.""" + +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) + +        self.text_channels = [MockTextChannel() for _ in range(2)] +        self.bot.get_channel.return_value = self.text_channels[1] + +        self.voice_channel = MockVoiceChannel() + +    async def test_send_to_channel(self): +        """Tests a basic case for the send function.""" +        message = "Test basic message." +        await self.cog.send_message(message, *self.text_channels, alert_target=False) + +        self.text_channels[0].send.assert_awaited_once_with(message) +        self.text_channels[1].send.assert_not_called() + +    async def test_send_to_multiple_channels(self): +        """Tests sending messages to two channels.""" +        message = "Test basic message." +        await self.cog.send_message(message, *self.text_channels, alert_target=True) + +        self.text_channels[0].send.assert_awaited_once_with(message) +        self.text_channels[1].send.assert_awaited_once_with(message) + +    async def test_duration_replacement(self): +        """Tests that the channel name was set correctly for one target channel.""" +        message = "Current. The following should be replaced: {channel}." +        await self.cog.send_message(message, *self.text_channels, alert_target=False) + +        updated_message = message.format(channel=self.text_channels[0].mention) +        self.text_channels[0].send.assert_awaited_once_with(updated_message) +        self.text_channels[1].send.assert_not_called() + +    async def test_name_replacement_multiple_channels(self): +        """Tests that the channel name was set correctly for two channels.""" +        message = "Current. The following should be replaced: {channel}." +        await self.cog.send_message(message, *self.text_channels, alert_target=True) + +        self.text_channels[0].send.assert_awaited_once_with(message.format(channel=self.text_channels[0].mention)) +        self.text_channels[1].send.assert_awaited_once_with(message.format(channel="current channel")) + +    async def test_silence_voice(self): +        """Tests that the correct message was sent when a voice channel is muted without alerting.""" +        message = "This should show up just here." +        await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=False) +        self.text_channels[0].send.assert_awaited_once_with(message) +        self.text_channels[1].send.assert_not_called() + +    async def test_silence_voice_alert(self): +        """Tests that the correct message was sent when a voice channel is muted with alerts.""" +        with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: +            mock_voice_channels.get.return_value = self.text_channels[1].id + +            message = "This should show up as {channel}." +            await self.cog.send_message(message, self.text_channels[0], self.voice_channel, alert_target=True) + +        updated_message = message.format(channel=self.voice_channel.mention) +        self.text_channels[0].send.assert_awaited_once_with(updated_message) +        self.text_channels[1].send.assert_awaited_once_with(updated_message) + +        mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) + +    async def test_silence_voice_sibling_channel(self): +        """Tests silencing a voice channel from the related text channel.""" +        with unittest.mock.patch.object(silence, "VOICE_CHANNELS") as mock_voice_channels: +            mock_voice_channels.get.return_value = self.text_channels[1].id + +            message = "This should show up as {channel}." +            await self.cog.send_message(message, self.text_channels[1], self.voice_channel, alert_target=True) + +            updated_message = message.format(channel=self.voice_channel.mention) +            self.text_channels[1].send.assert_awaited_once_with(updated_message) + +            mock_voice_channels.get.assert_called_once_with(self.voice_channel.id) +            self.bot.get_channel.assert_called_once_with(self.text_channels[1].id) diff --git a/tests/bot/exts/utils/test_jams.py b/tests/bot/exts/utils/test_jams.py deleted file mode 100644 index 85d6a1173..000000000 --- a/tests/bot/exts/utils/test_jams.py +++ /dev/null @@ -1,171 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, create_autospec - -from discord import CategoryChannel - -from bot.constants import Roles -from bot.exts.utils import jams -from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel - - -def get_mock_category(channel_count: int, name: str) -> CategoryChannel: -    """Return a mocked code jam category.""" -    category = create_autospec(CategoryChannel, spec_set=True, instance=True) -    category.name = name -    category.channels = [MockTextChannel() for _ in range(channel_count)] - -    return category - - -class JamCreateTeamTests(unittest.IsolatedAsyncioTestCase): -    """Tests for `createteam` command.""" - -    def setUp(self): -        self.bot = MockBot() -        self.admin_role = MockRole(name="Admins", id=Roles.admins) -        self.command_user = MockMember([self.admin_role]) -        self.guild = MockGuild([self.admin_role]) -        self.ctx = MockContext(bot=self.bot, author=self.command_user, guild=self.guild) -        self.cog = jams.CodeJams(self.bot) - -    async def test_too_small_amount_of_team_members_passed(self): -        """Should `ctx.send` and exit early when too small amount of members.""" -        for case in (1, 2): -            with self.subTest(amount_of_members=case): -                self.cog.create_channels = AsyncMock() -                self.cog.add_roles = AsyncMock() - -                self.ctx.reset_mock() -                members = (MockMember() for _ in range(case)) -                await self.cog.createteam(self.cog, self.ctx, "foo", members) - -                self.ctx.send.assert_awaited_once() -                self.cog.create_channels.assert_not_awaited() -                self.cog.add_roles.assert_not_awaited() - -    async def test_duplicate_members_provided(self): -        """Should `ctx.send` and exit early because duplicate members provided and total there is only 1 member.""" -        self.cog.create_channels = AsyncMock() -        self.cog.add_roles = AsyncMock() - -        member = MockMember() -        await self.cog.createteam(self.cog, self.ctx, "foo", (member for _ in range(5))) - -        self.ctx.send.assert_awaited_once() -        self.cog.create_channels.assert_not_awaited() -        self.cog.add_roles.assert_not_awaited() - -    async def test_result_sending(self): -        """Should call `ctx.send` when everything goes right.""" -        self.cog.create_channels = AsyncMock() -        self.cog.add_roles = AsyncMock() - -        members = [MockMember() for _ in range(5)] -        await self.cog.createteam(self.cog, self.ctx, "foo", members) - -        self.cog.create_channels.assert_awaited_once() -        self.cog.add_roles.assert_awaited_once() -        self.ctx.send.assert_awaited_once() - -    async def test_category_doesnt_exist(self): -        """Should create a new code jam category.""" -        subtests = ( -            [], -            [get_mock_category(jams.MAX_CHANNELS - 1, jams.CATEGORY_NAME)], -            [get_mock_category(jams.MAX_CHANNELS - 2, "other")], -        ) - -        for categories in subtests: -            self.guild.reset_mock() -            self.guild.categories = categories - -            with self.subTest(categories=categories): -                actual_category = await self.cog.get_category(self.guild) - -                self.guild.create_category_channel.assert_awaited_once() -                category_overwrites = self.guild.create_category_channel.call_args[1]["overwrites"] - -                self.assertFalse(category_overwrites[self.guild.default_role].read_messages) -                self.assertTrue(category_overwrites[self.guild.me].read_messages) -                self.assertEqual(self.guild.create_category_channel.return_value, actual_category) - -    async def test_category_channel_exist(self): -        """Should not try to create category channel.""" -        expected_category = get_mock_category(jams.MAX_CHANNELS - 2, jams.CATEGORY_NAME) -        self.guild.categories = [ -            get_mock_category(jams.MAX_CHANNELS - 2, "other"), -            expected_category, -            get_mock_category(0, jams.CATEGORY_NAME), -        ] - -        actual_category = await self.cog.get_category(self.guild) -        self.assertEqual(expected_category, actual_category) - -    async def test_channel_overwrites(self): -        """Should have correct permission overwrites for users and roles.""" -        leader = MockMember() -        members = [leader] + [MockMember() for _ in range(4)] -        overwrites = self.cog.get_overwrites(members, self.guild) - -        # Leader permission overwrites -        self.assertTrue(overwrites[leader].manage_messages) -        self.assertTrue(overwrites[leader].read_messages) -        self.assertTrue(overwrites[leader].manage_webhooks) -        self.assertTrue(overwrites[leader].connect) - -        # Other members permission overwrites -        for member in members[1:]: -            self.assertTrue(overwrites[member].read_messages) -            self.assertTrue(overwrites[member].connect) - -        # Everyone role overwrite -        self.assertFalse(overwrites[self.guild.default_role].read_messages) -        self.assertFalse(overwrites[self.guild.default_role].connect) - -    async def test_team_channels_creation(self): -        """Should create new voice and text channel for team.""" -        members = [MockMember() for _ in range(5)] - -        self.cog.get_overwrites = MagicMock() -        self.cog.get_category = AsyncMock() -        self.ctx.guild.create_text_channel.return_value = MockTextChannel(mention="foobar-channel") -        actual = await self.cog.create_channels(self.guild, "my-team", members) - -        self.assertEqual("foobar-channel", actual) -        self.cog.get_overwrites.assert_called_once_with(members, self.guild) -        self.cog.get_category.assert_awaited_once_with(self.guild) - -        self.guild.create_text_channel.assert_awaited_once_with( -            "my-team", -            overwrites=self.cog.get_overwrites.return_value, -            category=self.cog.get_category.return_value -        ) -        self.guild.create_voice_channel.assert_awaited_once_with( -            "My Team", -            overwrites=self.cog.get_overwrites.return_value, -            category=self.cog.get_category.return_value -        ) - -    async def test_jam_roles_adding(self): -        """Should add team leader role to leader and jam role to every team member.""" -        leader_role = MockRole(name="Team Leader") -        jam_role = MockRole(name="Jammer") -        self.guild.get_role.side_effect = [leader_role, jam_role] - -        leader = MockMember() -        members = [leader] + [MockMember() for _ in range(4)] -        await self.cog.add_roles(self.guild, members) - -        leader.add_roles.assert_any_await(leader_role) -        for member in members: -            member.add_roles.assert_any_await(jam_role) - - -class CodeJamSetup(unittest.TestCase): -    """Test for `setup` function of `CodeJam` cog.""" - -    def test_setup(self): -        """Should call `bot.add_cog`.""" -        bot = MockBot() -        jams.setup(bot) -        bot.add_cog.assert_called_once() diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index 6444532f2..f8805ac48 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -2,12 +2,14 @@ from typing import Iterable  from bot.rules import mentions  from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockMessage +from tests.helpers import MockMember, MockMessage -def make_msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_user_mentions: int, total_bot_mentions: int = 0) -> MockMessage:      """Makes a message with `total_mentions` mentions.""" -    return MockMessage(author=author, mentions=list(range(total_mentions))) +    user_mentions = [MockMember() for _ in range(total_user_mentions)] +    bot_mentions = [MockMember(bot=True) for _ in range(total_bot_mentions)] +    return MockMessage(author=author, mentions=user_mentions+bot_mentions)  class TestMentions(RuleTest): @@ -48,11 +50,27 @@ class TestMentions(RuleTest):                  [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],                  ("bob",),                  4, -            ) +            ), +            DisallowedCase( +                [make_msg("bob", 3, 1)], +                ("bob",), +                3, +            ),          )          await self.run_disallowed(cases) +    async def test_ignore_bot_mentions(self): +        """Messages with an allowed amount of mentions, also containing bot mentions.""" +        cases = ( +            [make_msg("bob", 0, 3)], +            [make_msg("bob", 2, 1)], +            [make_msg("bob", 1, 2), make_msg("bob", 1, 2)], +            [make_msg("bob", 1, 5), make_msg("alice", 2, 5)] +        ) + +        await self.run_allowed(cases) +      def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:          last_message = case.recent_messages[0]          return tuple( diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index d0d7af1ba..f84de453d 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -11,7 +11,6 @@ from bot.converters import (      HushDurationConverter,      ISODateTime,      PackageName, -    TagContentConverter,  ) @@ -25,30 +24,6 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):          cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') -    async def test_tag_content_converter_for_valid(self): -        """TagContentConverter should return correct values for valid input.""" -        test_values = ( -            ('hello', 'hello'), -            ('  h ello  ', 'h ello'), -        ) - -        for content, expected_conversion in test_values: -            with self.subTest(content=content, expected_conversion=expected_conversion): -                conversion = await TagContentConverter.convert(self.context, content) -                self.assertEqual(conversion, expected_conversion) - -    async def test_tag_content_converter_for_invalid(self): -        """TagContentConverter should raise the proper exception for invalid input.""" -        test_values = ( -            ('', "Tag contents should not be empty, or filled with whitespace."), -            ('   ', "Tag contents should not be empty, or filled with whitespace."), -        ) - -        for value, exception_message in test_values: -            with self.subTest(tag_content=value, exception_message=exception_message): -                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): -                    await TagContentConverter.convert(self.context, value) -      async def test_package_name_for_valid(self):          """PackageName returns valid package names unchanged."""          test_values = ('foo', 'le_mon', 'num83r') @@ -262,7 +237,7 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):              ("10", 10),              ("5m", 5),              ("5M", 5), -            ("forever", None), +            ("forever", -1),          )          converter = HushDurationConverter()          for minutes_string, expected_minutes in test_values: diff --git a/tests/bot/utils/test_message_cache.py b/tests/bot/utils/test_message_cache.py new file mode 100644 index 000000000..04bfd28d1 --- /dev/null +++ b/tests/bot/utils/test_message_cache.py @@ -0,0 +1,214 @@ +import unittest + +from bot.utils.message_cache import MessageCache +from tests.helpers import MockMessage + + +# noinspection SpellCheckingInspection +class TestMessageCache(unittest.TestCase): +    """Tests for the MessageCache class in the `bot.utils.caching` module.""" + +    def test_first_append_sets_the_first_value(self): +        """Test if the first append adds the message to the first cell.""" +        cache = MessageCache(maxlen=10) +        message = MockMessage() + +        cache.append(message) + +        self.assertEqual(cache[0], message) + +    def test_append_adds_in_the_right_order(self): +        """Test if two appends are added in the same order if newest_first is False, or in reverse order otherwise.""" +        messages = [MockMessage(), MockMessage()] + +        cache = MessageCache(maxlen=10, newest_first=False) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages, list(cache)) + +        cache = MessageCache(maxlen=10, newest_first=True) +        for msg in messages: +            cache.append(msg) +        self.assertListEqual(messages[::-1], list(cache)) + +    def test_appending_over_maxlen_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages.""" +        cache = MessageCache(maxlen=2) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[1:], list(cache)) + +    def test_appending_over_maxlen_with_newest_first_removes_oldest(self): +        """Test if three appends to a 2-cell cache leave the two newest messages if newest_first is True.""" +        cache = MessageCache(maxlen=2, newest_first=True) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) + +        self.assertListEqual(messages[:0:-1], list(cache)) + +    def test_pop_removes_from_the_end(self): +        """Test if a pop removes the right-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.pop() + +        self.assertEqual(msg, messages[-1]) +        self.assertListEqual(messages[:-1], list(cache)) + +    def test_popleft_removes_from_the_beginning(self): +        """Test if a popleft removes the left-most message.""" +        cache = MessageCache(maxlen=3) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        msg = cache.popleft() + +        self.assertEqual(msg, messages[0]) +        self.assertListEqual(messages[1:], list(cache)) + +    def test_clear(self): +        """Test if a clear makes the cache empty.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] + +        for msg in messages: +            cache.append(msg) +        cache.clear() + +        self.assertListEqual(list(cache), []) +        self.assertEqual(len(cache), 0) + +    def test_get_message_returns_the_message(self): +        """Test if get_message returns the cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertEqual(cache.get_message(1234), message) + +    def test_get_message_returns_none(self): +        """Test if get_message returns None for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIsNone(cache.get_message(4321)) + +    def test_update_replaces_old_element(self): +        """Test if an update replaced the old message with the same ID.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) +        message = MockMessage(id=1234) +        cache.update(message) + +        self.assertIs(cache.get_message(1234), message) +        self.assertEqual(len(cache), 1) + +    def test_contains_returns_true_for_cached_message(self): +        """Test if contains returns True for an ID of a cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertIn(1234, cache) + +    def test_contains_returns_false_for_non_cached_message(self): +        """Test if contains returns False for an ID of a non-cached message.""" +        cache = MessageCache(maxlen=5) +        message = MockMessage(id=1234) + +        cache.append(message) + +        self.assertNotIn(4321, cache) + +    def test_indexing(self): +        """Test if the cache returns the correct messages by index.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(5)] + +        for msg in messages: +            cache.append(msg) + +        for current_loop in range(-5, 5): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(cache[current_loop], messages[current_loop]) + +    def test_bad_index_raises_index_error(self): +        """Test if the cache raises IndexError for invalid indices.""" +        cache = MessageCache(maxlen=5) +        messages = [MockMessage() for _ in range(3)] +        test_cases = (-10, -4, 3, 4, 5) + +        for msg in messages: +            cache.append(msg) + +        for current_loop in test_cases: +            with self.subTest(current_loop=current_loop): +                with self.assertRaises(IndexError): +                    cache[current_loop] + +    def test_slicing_with_unfilled_cache(self): +        """Test if slicing returns the correct messages if the cache is not yet fully filled.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size // 3 * 2)] + +            for msg in messages: +                cache.append(msg) + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_slicing_with_overfilled_cache(self): +        """Test if slicing returns the correct messages if the cache was appended with more messages it can contain.""" +        sizes = (5, 10, 55, 101) + +        slices = ( +            slice(None), slice(2, None), slice(None, 2), slice(None, None, 2), slice(None, None, 3), slice(-1, 2), +            slice(-1, 3000), slice(-3, -1), slice(-10, 3), slice(-10, 4, 2), slice(None, None, -1), slice(None, 3, -2), +            slice(None, None, -3), slice(-1, -10, -2), slice(-3, -7, -1) +        ) + +        for size in sizes: +            cache = MessageCache(maxlen=size) +            messages = [MockMessage() for _ in range(size * 3 // 2)] + +            for msg in messages: +                cache.append(msg) +            messages = messages[size // 2:] + +            for slice_ in slices: +                with self.subTest(current_loop=(size, slice_)): +                    self.assertListEqual(cache[slice_], messages[slice_]) + +    def test_length(self): +        """Test if len returns the correct number of items in the cache.""" +        cache = MessageCache(maxlen=5) + +        for current_loop in range(10): +            with self.subTest(current_loop=current_loop): +                self.assertEqual(len(cache), min(current_loop, 5)) +                cache.append(MockMessage()) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 115ddfb0d..8edffd1c9 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -52,7 +52,7 @@ class TimeTests(unittest.TestCase):      def test_format_infraction(self):          """Testing format_infraction.""" -        self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') +        self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '<t:1576108860:f>')      def test_format_infraction_with_duration_none_expiry(self):          """format_infraction_with_duration should work for None expiry.""" @@ -72,10 +72,10 @@ class TimeTests(unittest.TestCase):      def test_format_infraction_with_duration_custom_units(self):          """format_infraction_with_duration should work for custom max_units."""          test_cases = ( -            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, -             '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), -            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, -             '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') +            ('3000-12-12T00:01:00Z', datetime(3000, 12, 11, 12, 5, 5), 6, +             '<t:32533488060:f> (11 hours, 55 minutes and 55 seconds)'), +            ('3000-11-23T20:09:00Z', datetime(3000, 4, 25, 20, 15), 20, +             '<t:32531918940:f> (6 months, 28 days, 23 hours and 54 minutes)')          )          for expiry, date_from, max_units, expected in test_cases: @@ -85,16 +85,16 @@ class TimeTests(unittest.TestCase):      def test_format_infraction_with_duration_normal_usage(self):          """format_infraction_with_duration should work for normal usage, across various durations."""          test_cases = ( -            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), -            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), -            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), -            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), -            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), -            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), -            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), -            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '<t:1576108860:f> (12 hours and 55 seconds)'), +            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '<t:1576108860:f> (12 hours)'), +            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '<t:1576108800:f> (1 minute)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '<t:1574539740:f> (7 days and 23 hours)'), +            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '<t:1574539740:f> (6 months and 28 days)'), +            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '<t:1574542680:f> (5 minutes)'), +            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '<t:1574553600:f> (1 minute)'), +            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '<t:1574553540:f> (2 years and 4 months)'),              ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, -             '2019-11-23 23:59 (9 minutes and 55 seconds)'), +             '<t:1574553540:f> (9 minutes and 55 seconds)'),              (None, datetime(2019, 11, 23, 23, 49, 5), 2, None),          ) @@ -104,45 +104,30 @@ class TimeTests(unittest.TestCase):      def test_until_expiration_with_duration_none_expiry(self):          """until_expiration should work for None expiry.""" -        test_cases = ( -            (None, None, None, None), - -            # To make sure that now and max_units are not touched -            (None, 'Why hello there!', None, None), -            (None, None, float('inf'), None), -            (None, 'Why hello there!', float('inf'), None), -        ) - -        for expiry, now, max_units, expected in test_cases: -            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): -                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) +        self.assertEqual(time.until_expiration(None), None)      def test_until_expiration_with_duration_custom_units(self):          """until_expiration should work for custom max_units."""          test_cases = ( -            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), -            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') +            ('3000-12-12T00:01:00Z', '<t:32533488060:R>'), +            ('3000-11-23T20:09:00Z', '<t:32531918940:R>')          ) -        for expiry, now, max_units, expected in test_cases: -            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): -                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) +        for expiry, expected in test_cases: +            with self.subTest(expiry=expiry, expected=expected): +                self.assertEqual(time.until_expiration(expiry,), expected)      def test_until_expiration_normal_usage(self):          """until_expiration should work for normal usage, across various durations."""          test_cases = ( -            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), -            ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), -            ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), -            ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), -            ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), -            ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), -            ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), -            ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), -            ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), -            (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), +            ('3000-12-12T00:01:00Z', '<t:32533488060:R>'), +            ('3000-12-12T00:01:00Z', '<t:32533488060:R>'), +            ('3000-12-12T00:00:00Z', '<t:32533488000:R>'), +            ('3000-11-23T20:09:00Z', '<t:32531918940:R>'), +            ('3000-11-23T20:09:00Z', '<t:32531918940:R>'), +            (None, None),          ) -        for expiry, now, max_units, expected in test_cases: -            with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): -                self.assertEqual(time.until_expiration(expiry, now, max_units), expected) +        for expiry, expected in test_cases: +            with self.subTest(expiry=expiry, expected=expected): +                self.assertEqual(time.until_expiration(expiry), expected) diff --git a/tests/helpers.py b/tests/helpers.py index e3dc5fe5b..3978076ed 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,6 @@ from bot.async_stats import AsyncStatsClient  from bot.bot import Bot  from tests._autospec import autospec  # noqa: F401 other modules import it via this module -  for logger in logging.Logger.manager.loggerDict.values():      # Set all loggers to CRITICAL by default to prevent screen clutter during testing @@ -320,7 +319,10 @@ channel_data = {  }  state = unittest.mock.MagicMock()  guild = unittest.mock.MagicMock() -channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) +text_channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + +channel_data["type"] = "VoiceChannel" +voice_channel_instance = discord.VoiceChannel(state=state, guild=guild, data=channel_data)  class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): @@ -330,7 +332,24 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      Instances of this class will follow the specifications of `discord.TextChannel` instances. For      more information, see the `MockGuild` docstring.      """ -    spec_set = channel_instance +    spec_set = text_channel_instance + +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} +        super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + +        if 'mention' not in kwargs: +            self.mention = f"#{self.name}" + + +class MockVoiceChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): +    """ +    A MagicMock subclass to mock VoiceChannel objects. + +    Instances of this class will follow the specifications of `discord.VoiceChannel` instances. For +    more information, see the `MockGuild` docstring. +    """ +    spec_set = voice_channel_instance      def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} @@ -361,6 +380,27 @@ class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):          super().__init__(**collections.ChainMap(kwargs, default_kwargs)) +# Create CategoryChannel instance to get a realistic MagicMock of `discord.CategoryChannel` +category_channel_data = { +    'id': 1, +    'type': discord.ChannelType.category, +    'name': 'category', +    'position': 1, +} + +state = unittest.mock.MagicMock() +guild = unittest.mock.MagicMock() +category_channel_instance = discord.CategoryChannel( +    state=state, guild=guild, data=category_channel_data +) + + +class MockCategoryChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id)} +        super().__init__(**collections.ChainMap(default_kwargs, kwargs)) + +  # Create a Message instance to get a realistic MagicMock of `discord.Message`  message_data = {      'id': 1, @@ -403,6 +443,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) +        self.message = kwargs.get('message', MockMessage())          self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) | 
