diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/cogs/moderation/test_infractions.py | 55 | ||||
| -rw-r--r-- | tests/bot/cogs/moderation/test_modlog.py | 29 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_cog.py | 3 | ||||
| -rw-r--r-- | tests/bot/cogs/sync/test_users.py | 2 | ||||
| -rw-r--r-- | tests/bot/cogs/test_antimalware.py | 159 | ||||
| -rw-r--r-- | tests/bot/cogs/test_cogs.py | 4 | ||||
| -rw-r--r-- | tests/bot/cogs/test_duck_pond.py | 2 | ||||
| -rw-r--r-- | tests/bot/cogs/test_information.py | 19 | ||||
| -rw-r--r-- | tests/bot/cogs/test_snekbox.py | 26 | ||||
| -rw-r--r-- | tests/bot/test_constants.py | 43 | ||||
| -rw-r--r-- | tests/bot/test_decorators.py | 147 | ||||
| -rw-r--r-- | tests/bot/utils/test_checks.py | 52 | ||||
| -rw-r--r-- | tests/bot/utils/test_redis_cache.py | 273 | ||||
| -rw-r--r-- | tests/helpers.py | 57 | 
14 files changed, 824 insertions, 47 deletions
| diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py new file mode 100644 index 000000000..da4e92ccc --- /dev/null +++ b/tests/bot/cogs/moderation/test_infractions.py @@ -0,0 +1,55 @@ +import textwrap +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from bot.cogs.moderation.infractions import Infractions +from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole + + +class TruncationTests(unittest.IsolatedAsyncioTestCase): +    """Tests for ban and kick command reason truncation.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = Infractions(self.bot) +        self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10)) +        self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0)) +        self.guild = MockGuild(id=4567) +        self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild) + +    @patch("bot.cogs.moderation.utils.get_active_infraction") +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock): +        """Should truncate reason for `ctx.guild.ban`.""" +        get_active_mock.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.bot.get_cog.return_value = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.ctx.guild.ban = Mock() + +        await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000) +        self.ctx.guild.ban.assert_called_once_with( +            self.target, +            reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."), +            delete_message_days=0 +        ) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value +        ) + +    @patch("bot.cogs.moderation.utils.post_infraction") +    async def test_apply_kick_reason_truncation(self, post_infraction_mock): +        """Should truncate reason for `Member.kick`.""" +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.cog.apply_infraction = AsyncMock() +        self.cog.mod_log.ignore = Mock() +        self.target.kick = Mock() + +        await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000) +        self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="...")) +        self.cog.apply_infraction.assert_awaited_once_with( +            self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value +        ) diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py new file mode 100644 index 000000000..f2809f40a --- /dev/null +++ b/tests/bot/cogs/moderation/test_modlog.py @@ -0,0 +1,29 @@ +import unittest + +import discord + +from bot.cogs.moderation.modlog import ModLog +from tests.helpers import MockBot, MockTextChannel + + +class ModLogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for moderation logs.""" + +    def setUp(self): +        self.bot = MockBot() +        self.cog = ModLog(self.bot) +        self.channel = MockTextChannel() + +    async def test_log_entry_description_truncation(self): +        """Test that embed description for ModLog entry is truncated.""" +        self.bot.get_channel.return_value = self.channel +        await self.cog.send_log_message( +            icon_url="foo", +            colour=discord.Colour.blue(), +            title="bar", +            text="foo bar" * 3000 +        ) +        embed = self.channel.send.call_args[1]["embed"] +        self.assertEqual( +            embed.description, ("foo bar" * 3000)[:2045] + "..." +        ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py index 81398c61f..14fd909c4 100644 --- a/tests/bot/cogs/sync/test_cog.py +++ b/tests/bot/cogs/sync/test_cog.py @@ -247,14 +247,12 @@ class SyncCogListenerTests(SyncCogTestCase):          before_data = {              "name": "old name",              "discriminator": "1234", -            "avatar": "old avatar",              "bot": False,          }          subtests = (              (True, "name", "name", "new name", "new name"),              (True, "discriminator", "discriminator", "8765", 8765), -            (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),              (False, "bot", "bot", True, True),          ) @@ -295,7 +293,6 @@ class SyncCogListenerTests(SyncCogTestCase):          )          data = { -            "avatar_hash": member.avatar,              "discriminator": int(member.discriminator),              "id": member.id,              "in_guild": True, diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index 818883012..002a947ad 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -10,7 +10,6 @@ def fake_user(**kwargs):      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man")      kwargs.setdefault("discriminator", 1337) -    kwargs.setdefault("avatar_hash", None)      kwargs.setdefault("roles", (666,))      kwargs.setdefault("in_guild", True) @@ -32,7 +31,6 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          for member in members:              member = member.copy() -            member["avatar"] = member.pop("avatar_hash")              del member["in_guild"]              mock_member = helpers.MockMember(**member) diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py new file mode 100644 index 000000000..f219fc1ba --- /dev/null +++ b/tests/bot/cogs/test_antimalware.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from discord import NotFound + +from bot.cogs import antimalware +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole + +MODULE = "bot.cogs.antimalware" + + +@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): +    """Test the AntiMalware cog.""" + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.bot = MockBot() +        self.cog = antimalware.AntiMalware(self.bot) +        self.message = MockMessage() + +    async def test_message_with_allowed_attachment(self): +        """Messages with allowed extensions should not be deleted""" +        attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_message_without_attachment(self): +        """Messages without attachments should result in no action.""" +        await self.cog.on_message(self.message) +        self.message.delete.assert_not_called() + +    async def test_direct_message_with_attachment(self): +        """Direct messages should have no action taken.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.guild = None + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_message_with_illegal_extension_gets_deleted(self): +        """A message containing an illegal extension should send an embed.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_called_once() + +    async def test_message_send_by_staff(self): +        """A message send by a member of staff should be ignored.""" +        staff_role = MockRole(id=STAFF_ROLES[0]) +        self.message.author.roles.append(staff_role) +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        await self.cog.on_message(self.message) + +        self.message.delete.assert_not_called() + +    async def test_python_file_redirect_embed_description(self): +        """A message containing a .py file should result in an embed redirecting the user to our paste site""" +        attachment = MockAttachment(filename="python.py") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") + +        self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION) + +    async def test_txt_file_redirect_embed_description(self): +        """A message containing a .txt file should result in the correct embed.""" +        attachment = MockAttachment(filename="python.txt") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.TXT_EMBED_DESCRIPTION = Mock() +        antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        cmd_channel = self.bot.get_channel(Channels.bot_commands) + +        self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) +        antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) + +    async def test_other_disallowed_extention_embed_description(self): +        """Test the description for a non .py/.txt disallowed extension.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.channel.send = AsyncMock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock() +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test" + +        await self.cog.on_message(self.message) +        self.message.channel.send.assert_called_once() +        args, kwargs = self.message.channel.send.call_args +        embed = kwargs.pop("embed") +        meta_channel = self.bot.get_channel(Channels.meta) + +        self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) +        antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( +            blocked_extensions_str=".disallowed", +            meta_channel_mention=meta_channel.mention +        ) + +    async def test_removing_deleted_message_logs(self): +        """Removing an already deleted message logs the correct message""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] +        self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) +        self.message.delete.assert_called_once() + +    async def test_message_with_illegal_attachment_logs(self): +        """Deleting a message with an illegal attachment should result in a log.""" +        attachment = MockAttachment(filename="python.disallowed") +        self.message.attachments = [attachment] + +        with self.assertLogs(logger=antimalware.log, level="INFO"): +            await self.cog.on_message(self.message) + +    async def test_get_disallowed_extensions(self): +        """The return value should include all non-whitelisted extensions.""" +        test_values = ( +            ([], []), +            (AntiMalwareConfig.whitelist, []), +            ([".first"], []), +            ([".first", ".disallowed"], [".disallowed"]), +            ([".disallowed"], [".disallowed"]), +            ([".disallowed", ".illegal"], [".disallowed", ".illegal"]), +        ) + +        for extensions, expected_disallowed_extensions in test_values: +            with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): +                self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] +                disallowed_extensions = self.cog.get_disallowed_extensions(self.message) +                self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) + + +class AntiMalwareSetupTests(unittest.TestCase): +    """Tests setup of the `AntiMalware` cog.""" + +    def test_setup(self): +        """Setup of the extension should call add_cog.""" +        bot = MockBot() +        antimalware.setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index 39f6492cb..fdda59a8f 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -31,7 +31,7 @@ class CommandNameTests(unittest.TestCase):      def walk_modules() -> t.Iterator[ModuleType]:          """Yield imported modules from the bot.cogs subpackage."""          def on_error(name: str) -> t.NoReturn: -            raise ImportError(name=name) +            raise ImportError(name=name)  # pragma: no cover          # The mock prevents asyncio.get_event_loop() from being called.          with mock.patch("discord.ext.tasks.loop"): @@ -71,7 +71,7 @@ class CommandNameTests(unittest.TestCase):              for name in self.get_qualified_names(cmd):                  with self.subTest(cmd=func_name, name=name): -                    if name in all_names: +                    if name in all_names:  # pragma: no cover                          conflicts = ", ".join(all_names.get(name, ""))                          self.fail(                              f"Name '{name}' of the command {func_name} conflicts with {conflicts}." diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index 7e6bfc748..a8c0107c6 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):          self.assertEqual(cog.bot, bot)          self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) -        bot.loop.create_loop.called_once_with(cog.fetch_webhook()) +        bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())      def test_fetch_webhook_succeeds_without_connectivity_issues(self):          """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 3c26374f5..79c0e0ad3 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -7,10 +7,9 @@ import discord  from bot import constants  from bot.cogs import information -from bot.decorators import InChannelCheckFailure +from bot.utils.checks import InWhitelistCheckFailure  from tests import helpers -  COG_PATH = "bot.cogs.information.Information" @@ -149,14 +148,18 @@ class InformationCogTests(unittest.TestCase):                  Voice region: {self.ctx.guild.region}                  Features: {', '.join(self.ctx.guild.features)} -                **Counts** -                Members: {self.ctx.guild.member_count:,} -                Roles: {len(self.ctx.guild.roles)} +                **Channel counts**                  Category channels: 1                  Text channels: 1                  Voice channels: 1 +                Staff channels: 0 + +                **Member counts** +                Members: {self.ctx.guild.member_count:,} +                Staff members: 0 +                Roles: {len(self.ctx.guild.roles)} -                **Members** +                **Member statuses**                  {constants.Emojis.status_online} 2                  {constants.Emojis.status_idle} 1                  {constants.Emojis.status_dnd} 4 @@ -485,7 +488,7 @@ class UserEmbedTests(unittest.TestCase):          user.avatar_url_as.return_value = "avatar url"          embed = asyncio.run(self.cog.create_user_embed(ctx, user)) -        user.avatar_url_as.assert_called_once_with(format="png") +        user.avatar_url_as.assert_called_once_with(static_format="png")          self.assertEqual(embed.thumbnail.url, "avatar url") @@ -525,7 +528,7 @@ class UserCommandTests(unittest.TestCase):          ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))          msg = "Sorry, but you may only use this command within <#50>." -        with self.assertRaises(InChannelCheckFailure, msg=msg): +        with self.assertRaises(InWhitelistCheckFailure, msg=msg):              asyncio.run(self.cog.user_info.callback(self.cog, ctx))      @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py index 1dec0ccaf..cf9adbee0 100644 --- a/tests/bot/cogs/test_snekbox.py +++ b/tests/bot/cogs/test_snekbox.py @@ -21,7 +21,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Post the eval code to the URLs.snekbox_eval_api endpoint."""          resp = MagicMock()          resp.json = AsyncMock(return_value="return") -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(await self.cog.post_eval("import random"), "return")          self.bot.http_session.post.assert_called_with( @@ -41,7 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          key = "MarkDiamond"          resp = MagicMock()          resp.json = AsyncMock(return_value={"key": key}) -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          self.assertEqual(              await self.cog.upload_output("My awesome output"), @@ -57,7 +63,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          """Output upload gracefully fallback if the upload fail."""          resp = MagicMock()          resp.json = AsyncMock(side_effect=Exception) -        self.bot.http_session.post().__aenter__.return_value = resp + +        context_manager = MagicMock() +        context_manager.__aenter__.return_value = resp +        self.bot.http_session.post.return_value = context_manager          log = logging.getLogger("bot.cogs.snekbox")          with self.assertLogs(logger=log, level='ERROR'): @@ -208,10 +217,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):      async def test_eval_command_call_help(self):          """Test if the eval command call the help command if no code is provided.""" -        ctx = MockContext() -        ctx.invoke = AsyncMock() +        ctx = MockContext(command="sentinel")          await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') -        ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") +        ctx.send_help.assert_called_once_with("sentinel")      async def test_send_eval(self):          """Test the send_eval function.""" @@ -291,7 +299,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(actual, expected)          self.bot.wait_for.assert_has_awaits(              ( -                call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), +                call( +                    'message_edit', +                    check=partial_mock(snekbox.predicate_eval_message_edit, ctx), +                    timeout=snekbox.REEVAL_TIMEOUT, +                ),                  call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)              )          ) diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py index dae7c066c..f10d6fbe8 100644 --- a/tests/bot/test_constants.py +++ b/tests/bot/test_constants.py @@ -1,14 +1,40 @@  import inspect +import typing  import unittest  from bot import constants +def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool: +    """ +    Return True if `value` is an instance of the type represented by `annotation`. + +    This doesn't account for things like Unions or checking for homogenous types in collections. +    """ +    origin = typing.get_origin(annotation) + +    # This is done in case a bare e.g. `typing.List` is used. +    # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`. +    # `get_origin()` does this normalisation for us. +    type_ = annotation if origin is None else origin + +    return isinstance(value, type_) + + +def is_any_instance(value: typing.Any, types: typing.Collection) -> bool: +    """Return True if `value` is an instance of any type in `types`.""" +    for type_ in types: +        if is_annotation_instance(value, type_): +            return True + +    return False + +  class ConstantsTests(unittest.TestCase):      """Tests for our constants."""      def test_section_configuration_matches_type_specification(self): -        """The section annotations should match the actual types of the sections.""" +        """"The section annotations should match the actual types of the sections."""          sections = (              cls @@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):          )          for section in sections:              for name, annotation in section.__annotations__.items(): -                with self.subTest(section=section, name=name, annotation=annotation): +                with self.subTest(section=section.__name__, name=name, annotation=annotation):                      value = getattr(section, name) +                    origin = typing.get_origin(annotation) +                    annotation_args = typing.get_args(annotation) +                    failure_msg = f"{value} is not an instance of {annotation}" -                    if getattr(annotation, '_name', None) in ('Dict', 'List'): -                        self.skipTest("Cannot validate containers yet.") - -                    self.assertIsInstance(value, annotation) +                    if origin is typing.Union: +                        is_instance = is_any_instance(value, annotation_args) +                        self.assertTrue(is_instance, failure_msg) +                    else: +                        is_instance = is_annotation_instance(value, annotation) +                        self.assertTrue(is_instance, failure_msg) diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py new file mode 100644 index 000000000..3d450caa0 --- /dev/null +++ b/tests/bot/test_decorators.py @@ -0,0 +1,147 @@ +import collections +import unittest +import unittest.mock + +from bot import constants +from bot.decorators import in_whitelist +from bot.utils.checks import InWhitelistCheckFailure +from tests import helpers + +InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description")) + + +class InWhitelistTests(unittest.TestCase): +    """Tests for the `in_whitelist` check.""" + +    @classmethod +    def setUpClass(cls): +        """Set up helpers that only need to be defined once.""" +        cls.bot_commands = helpers.MockTextChannel(id=123456789, category_id=123456) +        cls.help_channel = helpers.MockTextChannel(id=987654321, category_id=987654) +        cls.non_whitelisted_channel = helpers.MockTextChannel(id=666666) +        cls.dm_channel = helpers.MockDMChannel() + +        cls.non_staff_member = helpers.MockMember() +        cls.staff_role = helpers.MockRole(id=121212) +        cls.staff_member = helpers.MockMember(roles=(cls.staff_role,)) + +        cls.channels = (cls.bot_commands.id,) +        cls.categories = (cls.help_channel.category_id,) +        cls.roles = (cls.staff_role.id,) + +    def test_predicate_returns_true_for_whitelisted_context(self): +        """The predicate should return `True` if a whitelisted context was passed to it.""" +        test_cases = ( +            InWhitelistTestCase( +                kwargs={"channels": self.channels}, +                ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member), +                description="In whitelisted channels by members without whitelisted roles", +            ), +            InWhitelistTestCase( +                kwargs={"redirect": self.bot_commands.id}, +                ctx=helpers.MockContext(channel=self.bot_commands, author=self.non_staff_member), +                description="`redirect` should be implicitly added to `channels`", +            ), +            InWhitelistTestCase( +                kwargs={"categories": self.categories}, +                ctx=helpers.MockContext(channel=self.help_channel, author=self.non_staff_member), +                description="Whitelisted category without whitelisted role", +            ), +            InWhitelistTestCase( +                kwargs={"roles": self.roles}, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.staff_member), +                description="Whitelisted role outside of whitelisted channel/category" +            ), +            InWhitelistTestCase( +                kwargs={ +                    "channels": self.channels, +                    "categories": self.categories, +                    "roles": self.roles, +                    "redirect": self.bot_commands, +                }, +                ctx=helpers.MockContext(channel=self.help_channel, author=self.staff_member), +                description="Case with all whitelist kwargs used", +            ), +        ) + +        for test_case in test_cases: +            # patch `commands.check` with a no-op lambda that just returns the predicate passed to it +            # so we can test the predicate that was generated from the specified kwargs. +            with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): +                predicate = in_whitelist(**test_case.kwargs) + +            with self.subTest(test_description=test_case.description): +                self.assertTrue(predicate(test_case.ctx)) + +    def test_predicate_raises_exception_for_non_whitelisted_context(self): +        """The predicate should raise `InWhitelistCheckFailure` for a non-whitelisted context.""" +        test_cases = ( +            # Failing check with explicit `redirect` +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                    "redirect": self.bot_commands.id, +                }, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), +                description="Failing check with an explicit redirect channel", +            ), + +            # Failing check with implicit `redirect` +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                }, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), +                description="Failing check with an implicit redirect channel", +            ), + +            # Failing check without `redirect` +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                    "redirect": None, +                }, +                ctx=helpers.MockContext(channel=self.non_whitelisted_channel, author=self.non_staff_member), +                description="Failing check without a redirect channel", +            ), + +            # Command issued in DM channel +            InWhitelistTestCase( +                kwargs={ +                    "categories": self.categories, +                    "channels": self.channels, +                    "roles": self.roles, +                    "redirect": None, +                }, +                ctx=helpers.MockContext(channel=self.dm_channel, author=self.dm_channel.me), +                description="Commands issued in DM channel should be rejected", +            ), +        ) + +        for test_case in test_cases: +            if "redirect" not in test_case.kwargs or test_case.kwargs["redirect"] is not None: +                # There are two cases in which we have a redirect channel: +                #   1. No redirect channel was passed; the default value of `bot_commands` is used +                #   2. An explicit `redirect` is set that is "not None" +                redirect_channel = test_case.kwargs.get("redirect", constants.Channels.bot_commands) +                redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" +            else: +                # If an explicit `None` was passed for `redirect`, there is no redirect channel +                redirect_message = "" + +            exception_message = f"You are not allowed to use that command{redirect_message}." + +            # patch `commands.check` with a no-op lambda that just returns the predicate passed to it +            # so we can test the predicate that was generated from the specified kwargs. +            with unittest.mock.patch("bot.decorators.commands.check", new=lambda predicate: predicate): +                predicate = in_whitelist(**test_case.kwargs) + +            with self.subTest(test_description=test_case.description): +                with self.assertRaisesRegex(InWhitelistCheckFailure, exception_message): +                    predicate(test_case.ctx) diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py index 9610771e5..de72e5748 100644 --- a/tests/bot/utils/test_checks.py +++ b/tests/bot/utils/test_checks.py @@ -1,6 +1,8 @@  import unittest +from unittest.mock import MagicMock  from bot.utils import checks +from bot.utils.checks import InWhitelistCheckFailure  from tests.helpers import MockContext, MockRole @@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):          self.ctx.author.roles.append(MockRole(id=role_id))          self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) -    def test_in_channel_check_for_correct_channel(self): -        self.ctx.channel.id = 42 -        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_correct_channel(self): +        """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list.""" +        channel_id = 3 +        self.ctx.channel.id = channel_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id])) -    def test_in_channel_check_for_incorrect_channel(self): -        self.ctx.channel.id = 42 + 10 -        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) +    def test_in_whitelist_check_incorrect_channel(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match.""" +        self.ctx.channel.id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, [4]) + +    def test_in_whitelist_check_correct_category(self): +        """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list.""" +        category_id = 3 +        self.ctx.channel.category_id = category_id +        self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id])) + +    def test_in_whitelist_check_incorrect_category(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match.""" +        self.ctx.channel.category_id = 3 +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, categories=[4]) + +    def test_in_whitelist_check_correct_role(self): +        """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6])) + +    def test_in_whitelist_check_incorrect_role(self): +        """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match.""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        with self.assertRaises(InWhitelistCheckFailure): +            checks.in_whitelist_check(self.ctx, roles=[4]) + +    def test_in_whitelist_check_fail_silently(self): +        """`in_whitelist_check` test no exception raised if `fail_silently` is `True`""" +        self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True)) + +    def test_in_whitelist_check_complex(self): +        """`in_whitelist_check` test with multiple parameters""" +        self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2)) +        self.ctx.channel.category_id = 3 +        self.ctx.channel.id = 5 +        self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2])) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py new file mode 100644 index 000000000..8c1a40640 --- /dev/null +++ b/tests/bot/utils/test_redis_cache.py @@ -0,0 +1,273 @@ +import asyncio +import unittest + +import fakeredis.aioredis + +from bot.utils import RedisCache +from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError +from tests import helpers + + +class RedisCacheTests(unittest.IsolatedAsyncioTestCase): +    """Tests the RedisCache class from utils.redis_dict.py.""" + +    async def asyncSetUp(self):  # noqa: N802 +        """Sets up the objects that only have to be initialized once.""" +        self.bot = helpers.MockBot() +        self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() + +        # Okay, so this is necessary so that we can create a clean new +        # class for every test method, and we want that because it will +        # ensure we get a fresh loop, which is necessary for test_increment_lock +        # to be able to pass. +        class DummyCog: +            """A dummy cog, for dummies.""" + +            redis = RedisCache() + +            def __init__(self, bot: helpers.MockBot): +                self.bot = bot + +        self.cog = DummyCog(self.bot) + +        await self.cog.redis.clear() + +    def test_class_attribute_namespace(self): +        """Test that RedisDict creates a namespace automatically for class attributes.""" +        self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") + +    async def test_class_attribute_required(self): +        """Test that errors are raised when not assigned as a class attribute.""" +        bad_cache = RedisCache() +        self.assertIs(bad_cache._namespace, None) + +        with self.assertRaises(RuntimeError): +            await bad_cache.set("test", "me_up_deadman") + +    def test_namespace_collision(self): +        """Test that we prevent colliding namespaces.""" +        bob_cache_1 = RedisCache() +        bob_cache_1._set_namespace("BobRoss") +        self.assertEqual(bob_cache_1._namespace, "BobRoss") + +        bob_cache_2 = RedisCache() +        bob_cache_2._set_namespace("BobRoss") +        self.assertEqual(bob_cache_2._namespace, "BobRoss_") + +    async def test_set_get_item(self): +        """Test that users can set and get items from the RedisDict.""" +        test_cases = ( +            ('favorite_fruit', 'melon'), +            ('favorite_number', 86), +            ('favorite_fraction', 86.54) +        ) + +        # Test that we can get and set different types. +        for test in test_cases: +            await self.cog.redis.set(*test) +            self.assertEqual(await self.cog.redis.get(test[0]), test[1]) + +        # Test that .get allows a default value +        self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") + +    async def test_set_item_type(self): +        """Test that .set rejects keys and values that are not permitted.""" +        fruits = ["lemon", "melon", "apple"] + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(fruits, "nice") + +        with self.assertRaises(TypeError): +            await self.cog.redis.set(4.23, "nice") + +    async def test_delete_item(self): +        """Test that .delete allows us to delete stuff from the RedisCache.""" +        # Add an item and verify that it gets added +        await self.cog.redis.set("internet", "firetruck") +        self.assertEqual(await self.cog.redis.get("internet"), "firetruck") + +        # Delete that item and verify that it gets deleted +        await self.cog.redis.delete("internet") +        self.assertIs(await self.cog.redis.get("internet"), None) + +    async def test_contains(self): +        """Test that we can check membership with .contains.""" +        await self.cog.redis.set('favorite_country', "Burkina Faso") + +        self.assertIs(await self.cog.redis.contains('favorite_country'), True) +        self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) + +    async def test_items(self): +        """Test that the RedisDict can be iterated.""" +        # Set up our test cases in the Redis cache +        test_cases = [ +            ('favorite_turtle', 'Donatello'), +            ('second_favorite_turtle', 'Leonardo'), +            ('third_favorite_turtle', 'Raphael'), +        ] +        for key, value in test_cases: +            await self.cog.redis.set(key, value) + +        # Consume the AsyncIterator into a regular list, easier to compare that way. +        redis_items = [item for item in await self.cog.redis.items()] + +        # These sequences are probably in the same order now, but probably +        # isn't good enough for tests. Let's not rely on .hgetall always +        # returning things in sequence, and just sort both lists to be safe. +        redis_items = sorted(redis_items) +        test_cases = sorted(test_cases) + +        # If these are equal now, everything works fine. +        self.assertSequenceEqual(test_cases, redis_items) + +    async def test_length(self): +        """Test that we can get the correct .length from the RedisDict.""" +        await self.cog.redis.set('one', 1) +        await self.cog.redis.set('two', 2) +        await self.cog.redis.set('three', 3) +        self.assertEqual(await self.cog.redis.length(), 3) + +        await self.cog.redis.set('four', 4) +        self.assertEqual(await self.cog.redis.length(), 4) + +    async def test_to_dict(self): +        """Test that the .to_dict method returns a workable dictionary copy.""" +        copy = await self.cog.redis.to_dict() +        local_copy = {key: value for key, value in await self.cog.redis.items()} +        self.assertIs(type(copy), dict) +        self.assertDictEqual(copy, local_copy) + +    async def test_clear(self): +        """Test that the .clear method removes the entire hash.""" +        await self.cog.redis.set('teddy', 'with me') +        await self.cog.redis.set('in my dreams', 'you have a weird hat') +        self.assertEqual(await self.cog.redis.length(), 2) + +        await self.cog.redis.clear() +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_pop(self): +        """Test that we can .pop an item from the RedisDict.""" +        await self.cog.redis.set('john', 'was afraid') + +        self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') +        self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') +        self.assertEqual(await self.cog.redis.length(), 0) + +    async def test_update(self): +        """Test that we can .update the RedisDict with multiple items.""" +        await self.cog.redis.set("reckfried", "lona") +        await self.cog.redis.set("bel air", "prince") +        await self.cog.redis.update({ +            "reckfried": "jona", +            "mega": "hungry, though", +        }) + +        result = { +            "reckfried": "jona", +            "bel air": "prince", +            "mega": "hungry, though", +        } +        self.assertDictEqual(await self.cog.redis.to_dict(), result) + +    def test_typestring_conversion(self): +        """Test the typestring-related helper functions.""" +        conversion_tests = ( +            (12, "i|12"), +            (12.4, "f|12.4"), +            ("cowabunga", "s|cowabunga"), +        ) + +        # Test conversion to typestring +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) + +        # Test conversion from typestrings +        for _input, expected in conversion_tests: +            self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) + +        # Test that exceptions are raised on invalid input +        with self.assertRaises(TypeError): +            self.cog.redis._value_to_typestring(["internet"]) +            self.cog.redis._value_from_typestring("o|firedog") + +    async def test_increment_decrement(self): +        """Test .increment and .decrement methods.""" +        await self.cog.redis.set("entropic", 5) +        await self.cog.redis.set("disentropic", 12.5) + +        # Test default increment +        await self.cog.redis.increment("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 6) + +        # Test default decrement +        await self.cog.redis.decrement("entropic") +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # Test float increment with float +        await self.cog.redis.increment("disentropic", 2.0) +        self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) + +        # Test float increment with int +        await self.cog.redis.increment("disentropic", 2) +        self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) + +        # Test negative increments, because why not. +        await self.cog.redis.increment("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 0) + +        # Negative decrements? Sure. +        await self.cog.redis.decrement("entropic", -5) +        self.assertEqual(await self.cog.redis.get("entropic"), 5) + +        # What about if we use a negative float to decrement an int? +        # This should convert the type into a float. +        await self.cog.redis.decrement("entropic", -2.5) +        self.assertEqual(await self.cog.redis.get("entropic"), 7.5) + +        # Let's test that they raise the right errors +        with self.assertRaises(KeyError): +            await self.cog.redis.increment("doesn't_exist!") + +        await self.cog.redis.set("stringthing", "stringthing") +        with self.assertRaises(TypeError): +            await self.cog.redis.increment("stringthing") + +    async def test_increment_lock(self): +        """Test that we can't produce a race condition in .increment.""" +        await self.cog.redis.set("test_key", 0) +        tasks = [] + +        # Increment this a lot in different tasks +        for _ in range(100): +            task = asyncio.create_task( +                self.cog.redis.increment("test_key", 1) +            ) +            tasks.append(task) +        await asyncio.gather(*tasks) + +        # Confirm that the value has been incremented the exact right number of times. +        value = await self.cog.redis.get("test_key") +        self.assertEqual(value, 100) + +    async def test_exceptions_raised(self): +        """Testing that the various RuntimeErrors are reachable.""" +        class MyCog: +            cache = RedisCache() + +            def __init__(self): +                self.other_cache = RedisCache() + +        cog = MyCog() + +        # Raises "No Bot instance" +        with self.assertRaises(NoBotInstanceError): +            await cog.cache.get("john") + +        # Raises "RedisCache has no namespace" +        with self.assertRaises(NoNamespaceError): +            await cog.other_cache.get("was") + +        # Raises "You must access the RedisCache instance through the cog instance" +        with self.assertRaises(NoParentInstanceError): +            await MyCog.cache.get("afraid") diff --git a/tests/helpers.py b/tests/helpers.py index 8e13f0f28..faa839370 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,12 +4,15 @@ import collections  import itertools  import logging  import unittest.mock +from asyncio import AbstractEventLoop  from typing import Iterable, Optional  import discord +from aiohttp import ClientSession  from discord.ext.commands import Context  from bot.api import APIClient +from bot.async_stats import AsyncStatsClient  from bot.bot import Bot @@ -205,6 +208,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):          """Simplified position-based comparisons similar to those of `discord.Role`."""          return self.position < other.position +    def __ge__(self, other): +        """Simplified position-based comparisons similar to those of `discord.Role`.""" +        return self.position >= other.position +  # Create a Member instance to get a realistic Mock of `discord.Member`  member_data = {'user': 'lemon', 'roles': [1]} @@ -264,10 +271,16 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):      spec_set = APIClient -# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` -bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) -bot_instance.http_session = None -bot_instance.api_client = None +def _get_mock_loop() -> unittest.mock.Mock: +    """Return a mocked asyncio.AbstractEventLoop.""" +    loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True) + +    # Since calling `create_task` on our MockBot does not actually schedule the coroutine object +    # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object +    # to prevent "has not been awaited"-warnings. +    loop.create_task.side_effect = lambda coroutine: coroutine.close() + +    return loop  class MockBot(CustomMockMixin, unittest.mock.MagicMock): @@ -277,17 +290,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):      Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.      For more information, see the `MockGuild` docstring.      """ -    spec_set = bot_instance -    additional_spec_asyncs = ("wait_for",) +    spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) +    additional_spec_asyncs = ("wait_for", "redis_ready")      def __init__(self, **kwargs) -> None:          super().__init__(**kwargs) -        self.api_client = MockAPIClient() -        # Since calling `create_task` on our MockBot does not actually schedule the coroutine object -        # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object -        # to prevent "has not been awaited"-warnings. -        self.loop.create_task.side_effect = lambda coroutine: coroutine.close() +        self.loop = _get_mock_loop() +        self.api_client = MockAPIClient(loop=self.loop) +        self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True) +        self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)  # Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` @@ -315,7 +327,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):      """      spec_set = channel_instance -    def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: +    def __init__(self, **kwargs) -> None:          default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}          super().__init__(**collections.ChainMap(kwargs, default_kwargs)) @@ -323,6 +335,27 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):              self.mention = f"#{self.name}" +# Create data for the DMChannel instance +state = unittest.mock.MagicMock() +me = unittest.mock.MagicMock() +dm_channel_data = {"id": 1, "recipients": [unittest.mock.MagicMock()]} +dm_channel_instance = discord.DMChannel(me=me, state=state, data=dm_channel_data) + + +class MockDMChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): +    """ +    A MagicMock subclass to mock TextChannel objects. + +    Instances of this class will follow the specifications of `discord.TextChannel` instances. For +    more information, see the `MockGuild` docstring. +    """ +    spec_set = dm_channel_instance + +    def __init__(self, **kwargs) -> None: +        default_kwargs = {'id': next(self.discord_id), 'recipient': MockUser(), "me": MockUser()} +        super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + +  # Create a Message instance to get a realistic MagicMock of `discord.Message`  message_data = {      'id': 1, | 
