diff options
| author | 2019-10-25 10:12:23 -0700 | |
|---|---|---|
| committer | 2019-10-25 10:12:23 -0700 | |
| commit | e68e9ef9cc6d6670a1c6b6a712fe87be1f33d60b (patch) | |
| tree | cccbc64ce0b9056efb7b2c07cbec9826d77400ff /tests/bot | |
| parent | Remove bold tag when no channel is available (diff) | |
| parent | Merge pull request #501 from mathsman5133/reddit-makeover (diff) | |
Merge branch 'master' into compact_free
Diffstat (limited to '')
20 files changed, 1166 insertions, 3 deletions
| diff --git a/tests/cogs/__init__.py b/tests/bot/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/cogs/__init__.py +++ b/tests/bot/__init__.py diff --git a/tests/cogs/sync/__init__.py b/tests/bot/cogs/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/cogs/sync/__init__.py +++ b/tests/bot/cogs/__init__.py diff --git a/tests/rules/__init__.py b/tests/bot/cogs/sync/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/rules/__init__.py +++ b/tests/bot/cogs/sync/__init__.py diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py new file mode 100644 index 000000000..27ae27639 --- /dev/null +++ b/tests/bot/cogs/sync/test_roles.py @@ -0,0 +1,126 @@ +import unittest + +from bot.cogs.sync.syncers import Role, get_roles_for_sync + + +class GetRolesForSyncTests(unittest.TestCase): +    """Tests constructing the roles to synchronize with the site.""" + +    def test_get_roles_for_sync_empty_return_for_equal_roles(self): +        """No roles should be synced when no diff is found.""" +        api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} +        guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            (set(), set(), set()) +        ) + +    def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): +        """Roles to be synced are returned when non-ID attributes differ.""" +        api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} +        guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            (set(), guild_roles, set()) +        ) + +    def test_get_roles_only_returns_roles_that_require_update(self): +        """Roles that require an update should be returned as the second tuple element.""" +        api_roles = { +            Role(id=41, name='old name', colour=33, permissions=0x8, position=1), +            Role(id=53, name='other role', colour=55, permissions=0, position=3) +        } +        guild_roles = { +            Role(id=41, name='new name', colour=35, permissions=0x8, position=2), +            Role(id=53, name='other role', colour=55, permissions=0, position=3) +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                set(), +                {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, +                set(), +            ) +        ) + +    def test_get_roles_returns_new_roles_in_first_tuple_element(self): +        """Newly created roles are returned as the first tuple element.""" +        api_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +        } +        guild_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +            Role(id=53, name='other role', colour=55, permissions=0, position=2) +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, +                set(), +                set(), +            ) +        ) + +    def test_get_roles_returns_roles_to_update_and_new_roles(self): +        """Newly created and updated roles should be returned together.""" +        api_roles = { +            Role(id=41, name='old name', colour=35, permissions=0x8, position=1), +        } +        guild_roles = { +            Role(id=41, name='new name', colour=40, permissions=0x16, position=2), +            Role(id=53, name='other role', colour=55, permissions=0, position=3) +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, +                {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, +                set(), +            ) +        ) + +    def test_get_roles_returns_roles_to_delete(self): +        """Roles to be deleted should be returned as the third tuple element.""" +        api_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), +        } +        guild_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                set(), +                set(), +                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, +            ) +        ) + +    def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): +        """When roles were added, updated, and removed, all of them are returned properly.""" +        api_roles = { +            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), +            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), +            Role(id=71, name='to update', colour=99, permissions=0x9, position=3), +        } +        guild_roles = { +            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), +            Role(id=81, name='to create', colour=99, permissions=0x9, position=4), +            Role(id=71, name='updated', colour=101, permissions=0x5, position=3), +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, +                {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, +                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, +            ) +        ) diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py new file mode 100644 index 000000000..ccaf67490 --- /dev/null +++ b/tests/bot/cogs/sync/test_users.py @@ -0,0 +1,84 @@ +import unittest + +from bot.cogs.sync.syncers import User, get_users_for_sync + + +def fake_user(**kwargs): +    kwargs.setdefault('id', 43) +    kwargs.setdefault('name', 'bob the test man') +    kwargs.setdefault('discriminator', 1337) +    kwargs.setdefault('avatar_hash', None) +    kwargs.setdefault('roles', (666,)) +    kwargs.setdefault('in_guild', True) +    return User(**kwargs) + + +class GetUsersForSyncTests(unittest.TestCase): +    """Tests constructing the users to synchronize with the site.""" + +    def test_get_users_for_sync_returns_nothing_for_empty_params(self): +        """When no users are given, none are returned.""" +        self.assertEqual( +            get_users_for_sync({}, {}), +            (set(), set()) +        ) + +    def test_get_users_for_sync_returns_nothing_for_equal_users(self): +        """When no users are updated, none are returned.""" +        api_users = {43: fake_user()} +        guild_users = {43: fake_user()} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), set()) +        ) + +    def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): +        """When a non-ID-field differs, the user to update is returned.""" +        api_users = {43: fake_user()} +        guild_users = {43: fake_user(name='new fancy name')} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), {fake_user(name='new fancy name')}) +        ) + +    def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): +        """When new users join the guild, they are returned as the first tuple element.""" +        api_users = {43: fake_user()} +        guild_users = {43: fake_user(), 63: fake_user(id=63)} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            ({fake_user(id=63)}, set()) +        ) + +    def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): +        """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" +        api_users = {43: fake_user(), 63: fake_user(id=63)} +        guild_users = {43: fake_user()} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), {fake_user(id=63, in_guild=False)}) +        ) + +    def test_get_users_for_sync_updates_and_creates_users_as_needed(self): +        """When one user left and another one was updated, both are returned.""" +        api_users = {43: fake_user()} +        guild_users = {63: fake_user(id=63)} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            ({fake_user(id=63)}, {fake_user(in_guild=False)}) +        ) + +    def test_get_users_for_sync_does_not_duplicate_update_users(self): +        """When the API knows a user the guild doesn't, nothing is performed.""" +        api_users = {43: fake_user(in_guild=False)} +        guild_users = {} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), set()) +        ) diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py new file mode 100644 index 000000000..ce5472c71 --- /dev/null +++ b/tests/bot/cogs/test_antispam.py @@ -0,0 +1,35 @@ +import unittest + +from bot.cogs import antispam + + +class AntispamConfigurationValidationTests(unittest.TestCase): +    """Tests validation of the antispam cog configuration.""" + +    def test_default_antispam_config_is_valid(self): +        """The default antispam configuration is valid.""" +        validation_errors = antispam.validate_config() +        self.assertEqual(validation_errors, {}) + +    def test_unknown_rule_returns_error(self): +        """Configuring an unknown rule returns an error.""" +        self.assertEqual( +            antispam.validate_config({'invalid-rule': {}}), +            {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} +        ) + +    def test_missing_keys_returns_error(self): +        """Not configuring required keys returns an error.""" +        keys = (('interval', 'max'), ('max', 'interval')) +        for configured_key, unconfigured_key in keys: +            with self.subTest( +                configured_key=configured_key, +                unconfigured_key=unconfigured_key +            ): +                config = {'burst': {configured_key: 10}} +                error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" + +                self.assertEqual( +                    antispam.validate_config(config), +                    {'burst': error} +                ) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py new file mode 100644 index 000000000..9bbd35a91 --- /dev/null +++ b/tests/bot/cogs/test_information.py @@ -0,0 +1,164 @@ +import asyncio +import textwrap +import unittest +import unittest.mock + +import discord + +from bot import constants +from bot.cogs import information +from tests.helpers import AsyncMock, MockBot, MockContext, MockGuild, MockMember, MockRole + + +class InformationCogTests(unittest.TestCase): +    """Tests the Information cog.""" + +    @classmethod +    def setUpClass(cls): +        cls.moderator_role = MockRole(name="Moderator", role_id=constants.Roles.moderator) + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.bot = MockBot() + +        self.cog = information.Information(self.bot) + +        self.ctx = MockContext() +        self.ctx.author.roles.append(self.moderator_role) + +    def test_roles_command_command(self): +        """Test if the `role_info` command correctly returns the `moderator_role`.""" +        self.ctx.guild.roles.append(self.moderator_role) + +        self.cog.roles_info.can_run = AsyncMock() +        self.cog.roles_info.can_run.return_value = True + +        coroutine = self.cog.roles_info.callback(self.cog, self.ctx) + +        self.assertIsNone(asyncio.run(coroutine)) +        self.ctx.send.assert_called_once() + +        _, kwargs = self.ctx.send.call_args +        embed = kwargs.pop('embed') + +        self.assertEqual(embed.title, "Role information") +        self.assertEqual(embed.colour, discord.Colour.blurple()) +        self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n") +        self.assertEqual(embed.footer.text, "Total roles: 1") + +    def test_role_info_command(self): +        """Tests the `role info` command.""" +        dummy_role = MockRole( +            name="Dummy", +            role_id=112233445566778899, +            colour=discord.Colour.blurple(), +            position=10, +            members=[self.ctx.author], +            permissions=discord.Permissions(0) +        ) + +        admin_role = MockRole( +            name="Admins", +            role_id=998877665544332211, +            colour=discord.Colour.red(), +            position=3, +            members=[self.ctx.author], +            permissions=discord.Permissions(0), +        ) + +        self.ctx.guild.roles.append([dummy_role, admin_role]) + +        self.cog.role_info.can_run = AsyncMock() +        self.cog.role_info.can_run.return_value = True + +        coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) + +        self.assertIsNone(asyncio.run(coroutine)) + +        self.assertEqual(self.ctx.send.call_count, 2) + +        (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list + +        dummy_embed = dummy_kwargs["embed"] +        admin_embed = admin_kwargs["embed"] + +        self.assertEqual(dummy_embed.title, "Dummy info") +        self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) + +        self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) +        self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") +        self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") +        self.assertEqual(dummy_embed.fields[3].value, "1") +        self.assertEqual(dummy_embed.fields[4].value, "10") +        self.assertEqual(dummy_embed.fields[5].value, "0") + +        self.assertEqual(admin_embed.title, "Admins info") +        self.assertEqual(admin_embed.colour, discord.Colour.red()) + +    @unittest.mock.patch('bot.cogs.information.time_since') +    def test_server_info_command(self, time_since_patch): +        time_since_patch.return_value = '2 days ago' + +        self.ctx.guild = MockGuild( +            features=('lemons', 'apples'), +            region="The Moon", +            roles=[self.moderator_role], +            channels=[ +                discord.TextChannel( +                    state={}, +                    guild=self.ctx.guild, +                    data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} +                ), +                discord.CategoryChannel( +                    state={}, +                    guild=self.ctx.guild, +                    data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} +                ), +                discord.VoiceChannel( +                    state={}, +                    guild=self.ctx.guild, +                    data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} +                ) +            ], +            members=[ +                *(MockMember(status='online') for _ in range(2)), +                *(MockMember(status='idle') for _ in range(1)), +                *(MockMember(status='dnd') for _ in range(4)), +                *(MockMember(status='offline') for _ in range(3)), +            ], +            member_count=1_234, +            icon_url='a-lemon.jpg', +        ) + +        coroutine = self.cog.server_info.callback(self.cog, self.ctx) +        self.assertIsNone(asyncio.run(coroutine)) + +        time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') +        _, kwargs = self.ctx.send.call_args +        embed = kwargs.pop('embed') +        self.assertEqual(embed.colour, discord.Colour.blurple()) +        self.assertEqual( +            embed.description, +            textwrap.dedent( +                f""" +                **Server information** +                Created: {time_since_patch.return_value} +                Voice region: {self.ctx.guild.region} +                Features: {', '.join(self.ctx.guild.features)} + +                **Counts** +                Members: {self.ctx.guild.member_count:,} +                Roles: {len(self.ctx.guild.roles)} +                Text: 1 +                Voice: 1 +                Channel categories: 1 + +                **Members** +                {constants.Emojis.status_online} 2 +                {constants.Emojis.status_idle} 1 +                {constants.Emojis.status_dnd} 4 +                {constants.Emojis.status_offline} 3 +                """ +            ) +        ) +        self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py new file mode 100644 index 000000000..efa7a50b1 --- /dev/null +++ b/tests/bot/cogs/test_security.py @@ -0,0 +1,59 @@ +import logging +import unittest +from unittest.mock import MagicMock + +from discord.ext.commands import NoPrivateMessage + +from bot.cogs import security +from tests.helpers import MockBot, MockContext + + +class SecurityCogTests(unittest.TestCase): +    """Tests the `Security` cog.""" + +    def setUp(self): +        """Attach an instance of the cog to the class for tests.""" +        self.bot = MockBot() +        self.cog = security.Security(self.bot) +        self.ctx = MockContext() + +    def test_check_additions(self): +        """The cog should add its checks after initialization.""" +        self.bot.check.assert_any_call(self.cog.check_on_guild) +        self.bot.check.assert_any_call(self.cog.check_not_bot) + +    def test_check_not_bot_returns_false_for_humans(self): +        """The bot check should return `True` when invoked with human authors.""" +        self.ctx.author.bot = False +        self.assertTrue(self.cog.check_not_bot(self.ctx)) + +    def test_check_not_bot_returns_true_for_robots(self): +        """The bot check should return `False` when invoked with robotic authors.""" +        self.ctx.author.bot = True +        self.assertFalse(self.cog.check_not_bot(self.ctx)) + +    def test_check_on_guild_raises_when_outside_of_guild(self): +        """When invoked outside of a guild, `check_on_guild` should cause an error.""" +        self.ctx.guild = None + +        with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): +            self.cog.check_on_guild(self.ctx) + +    def test_check_on_guild_returns_true_inside_of_guild(self): +        """When invoked inside of a guild, `check_on_guild` should return `True`.""" +        self.ctx.guild = "lemon's lemonade stand" +        self.assertTrue(self.cog.check_on_guild(self.ctx)) + + +class SecurityCogLoadTests(unittest.TestCase): +    """Tests loading the `Security` cog.""" + +    def test_security_cog_load(self): +        """Cog loading logs a message at `INFO` level.""" +        bot = MagicMock() +        with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: +            security.setup(bot) +            bot.add_cog.assert_called_once() + +        [line] = cm.output +        self.assertIn("Cog loaded: Security", line) diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py new file mode 100644 index 000000000..dfb1bafc9 --- /dev/null +++ b/tests/bot/cogs/test_token_remover.py @@ -0,0 +1,135 @@ +import asyncio +import logging +import unittest +from unittest.mock import MagicMock + +from discord import Colour + +from bot.cogs.token_remover import ( +    DELETION_MESSAGE_TEMPLATE, +    TokenRemover, +    setup as setup_cog, +) +from bot.constants import Channels, Colours, Event, Icons +from tests.helpers import AsyncMock, MockBot, MockMessage + + +class TokenRemoverTests(unittest.TestCase): +    """Tests the `TokenRemover` cog.""" + +    def setUp(self): +        """Adds the cog, a bot, and a message to the instance for usage in tests.""" +        self.bot = MockBot() +        self.bot.get_cog.return_value = MagicMock() +        self.bot.get_cog.return_value.send_log_message = AsyncMock() +        self.cog = TokenRemover(bot=self.bot) + +        self.msg = MockMessage(message_id=555, content='') +        self.msg.author.__str__ = MagicMock() +        self.msg.author.__str__.return_value = 'lemon' +        self.msg.author.bot = False +        self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' +        self.msg.author.id = 42 +        self.msg.author.mention = '@lemon' +        self.msg.channel.mention = "#lemonade-stand" + +    def test_is_valid_user_id_is_true_for_numeric_content(self): +        """A string decoding to numeric characters is a valid user ID.""" +        # MTIz = base64(123) +        self.assertTrue(TokenRemover.is_valid_user_id('MTIz')) + +    def test_is_valid_user_id_is_false_for_alphabetic_content(self): +        """A string decoding to alphabetic characters is not a valid user ID.""" +        # YWJj = base64(abc) +        self.assertFalse(TokenRemover.is_valid_user_id('YWJj')) + +    def test_is_valid_timestamp_is_true_for_valid_timestamps(self): +        """A string decoding to a valid timestamp should be recognized as such.""" +        self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A')) + +    def test_is_valid_timestamp_is_false_for_invalid_values(self): +        """A string not decoding to a valid timestamp should not be recognized as such.""" +        # MTIz = base64(123) +        self.assertFalse(TokenRemover.is_valid_timestamp('MTIz')) + +    def test_mod_log_property(self): +        """The `mod_log` property should ask the bot to return the `ModLog` cog.""" +        self.bot.get_cog.return_value = 'lemon' +        self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) +        self.bot.get_cog.assert_called_once_with('ModLog') + +    def test_ignores_bot_messages(self): +        """When the message event handler is called with a bot message, nothing is done.""" +        self.msg.author.bot = True +        coroutine = self.cog.on_message(self.msg) +        self.assertIsNone(asyncio.run(coroutine)) + +    def test_ignores_messages_without_tokens(self): +        """Messages without anything looking like a token are ignored.""" +        for content in ('', 'lemon wins'): +            with self.subTest(content=content): +                self.msg.content = content +                coroutine = self.cog.on_message(self.msg) +                self.assertIsNone(asyncio.run(coroutine)) + +    def test_ignores_messages_with_invalid_tokens(self): +        """Messages with values that are invalid tokens are ignored.""" +        for content in ('foo.bar.baz', 'x.y.'): +            with self.subTest(content=content): +                self.msg.content = content +                coroutine = self.cog.on_message(self.msg) +                self.assertIsNone(asyncio.run(coroutine)) + +    def test_censors_valid_tokens(self): +        """Valid tokens are censored.""" +        cases = ( +            # (content, censored_token) +            ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), +        ) + +        for content, censored_token in cases: +            with self.subTest(content=content, censored_token=censored_token): +                self.msg.content = content +                coroutine = self.cog.on_message(self.msg) +                with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm: +                    self.assertIsNone(asyncio.run(coroutine))  # no return value + +                [line] = cm.output +                log_message = ( +                    "Censored a seemingly valid token sent by " +                    "lemon (`42`) in #lemonade-stand, " +                    f"token was `{censored_token}`" +                ) +                self.assertIn(log_message, line) + +                self.msg.delete.assert_called_once_with() +                self.msg.channel.send.assert_called_once_with( +                    DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') +                ) +                self.bot.get_cog.assert_called_with('ModLog') +                self.msg.author.avatar_url_as.assert_called_once_with(static_format='png') + +                mod_log = self.bot.get_cog.return_value +                mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) +                mod_log.send_log_message.assert_called_once_with( +                    icon_url=Icons.token_removed, +                    colour=Colour(Colours.soft_red), +                    title="Token removed!", +                    text=log_message, +                    thumbnail='picture-lemon.png', +                    channel_id=Channels.mod_alerts +                ) + + +class TokenRemoverSetupTests(unittest.TestCase): +    """Tests setup of the `TokenRemover` cog.""" + +    def test_setup(self): +        """Setup of the cog should log a message at `INFO` level.""" +        bot = MockBot() +        with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: +            setup_cog(bot) + +        [line] = cm.output +        bot.add_cog.assert_called_once() +        self.assertIn("Cog loaded: TokenRemover", line) diff --git a/tests/utils/__init__.py b/tests/bot/patches/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/utils/__init__.py +++ b/tests/bot/patches/__init__.py diff --git a/tests/bot/resources/__init__.py b/tests/bot/resources/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/resources/__init__.py diff --git a/tests/bot/resources/test_resources.py b/tests/bot/resources/test_resources.py new file mode 100644 index 000000000..73937cfa6 --- /dev/null +++ b/tests/bot/resources/test_resources.py @@ -0,0 +1,17 @@ +import json +import unittest +from pathlib import Path + + +class ResourceValidationTests(unittest.TestCase): +    """Validates resources used by the bot.""" +    def test_stars_valid(self): +        """The resource `bot/resources/stars.json` should contain a list of strings.""" +        path = Path('bot', 'resources', 'stars.json') +        content = path.read_text() +        data = json.loads(content) + +        self.assertIsInstance(data, list) +        for name in data: +            with self.subTest(name=name): +                self.assertIsInstance(name, str) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/rules/__init__.py diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py new file mode 100644 index 000000000..4bb0acf7c --- /dev/null +++ b/tests/bot/rules/test_attachments.py @@ -0,0 +1,52 @@ +import asyncio +import unittest +from dataclasses import dataclass +from typing import Any, List + +from bot.rules import attachments + + +# Using `MagicMock` sadly doesn't work for this usecase +# since it's __eq__ compares the MagicMock's ID. We just +# want to compare the actual attributes we set. +@dataclass +class FakeMessage: +    author: str +    attachments: List[Any] + + +def msg(total_attachments: int) -> FakeMessage: +    return FakeMessage(author='lemon', attachments=list(range(total_attachments))) + + +class AttachmentRuleTests(unittest.TestCase): +    """Tests applying the `attachment` antispam rule.""" + +    def test_allows_messages_without_too_many_attachments(self): +        """Messages without too many attachments are allowed as-is.""" +        cases = ( +            (msg(0), msg(0), msg(0)), +            (msg(2), msg(2)), +            (msg(0),), +        ) + +        for last_message, *recent_messages in cases: +            with self.subTest(last_message=last_message, recent_messages=recent_messages): +                coro = attachments.apply(last_message, recent_messages, {'max': 5}) +                self.assertIsNone(asyncio.run(coro)) + +    def test_disallows_messages_with_too_many_attachments(self): +        """Messages with too many attachments trigger the rule.""" +        cases = ( +            ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), +            ((msg(6),), [msg(6)], 6), +            ((msg(1),) * 6, [msg(1)] * 6, 6), +        ) +        for messages, relevant_messages, total in cases: +            with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total): +                last_message, *recent_messages = messages +                coro = attachments.apply(last_message, recent_messages, {'max': 5}) +                self.assertEqual( +                    asyncio.run(coro), +                    (f"sent {total} attachments in 5s", ('lemon',), relevant_messages) +                ) diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py new file mode 100644 index 000000000..e0ede0eb1 --- /dev/null +++ b/tests/bot/test_api.py @@ -0,0 +1,134 @@ +import logging +import unittest +from unittest.mock import MagicMock, patch + +from bot import api +from tests.base import LoggingTestCase +from tests.helpers import async_test + + +class APIClientTests(unittest.TestCase): +    """Tests for the bot's API client.""" + +    @classmethod +    def setUpClass(cls): +        """Sets up the shared fixtures for the tests.""" +        cls.error_api_response = MagicMock() +        cls.error_api_response.status = 999 + +    def test_loop_is_not_running_by_default(self): +        """The event loop should not be running by default.""" +        self.assertFalse(api.loop_is_running()) + +    @async_test +    async def test_loop_is_running_in_async_context(self): +        """The event loop should be running in an async context.""" +        self.assertTrue(api.loop_is_running()) + +    def test_response_code_error_default_initialization(self): +        """Test the default initialization of `ResponseCodeError` without `text` or `json`""" +        error = api.ResponseCodeError(response=self.error_api_response) + +        self.assertIs(error.status, self.error_api_response.status) +        self.assertEqual(error.response_json, {}) +        self.assertEqual(error.response_text, "") +        self.assertIs(error.response, self.error_api_response) + +    def test_responde_code_error_string_representation_default_initialization(self): +        """Test the string representation of `ResponseCodeError` initialized without text or json.""" +        error = api.ResponseCodeError(response=self.error_api_response) +        self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") + +    def test_response_code_error_initialization_with_json(self): +        """Test the initialization of `ResponseCodeError` with json.""" +        json_data = {'hello': 'world'} +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_json=json_data, +        ) +        self.assertEqual(error.response_json, json_data) +        self.assertEqual(error.response_text, "") + +    def test_response_code_error_string_representation_with_nonempty_response_json(self): +        """Test the string representation of `ResponseCodeError` initialized with json.""" +        json_data = {'hello': 'world'} +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_json=json_data +        ) +        self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") + +    def test_response_code_error_initialization_with_text(self): +        """Test the initialization of `ResponseCodeError` with text.""" +        text_data = 'Lemon will eat your soul' +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_text=text_data, +        ) +        self.assertEqual(error.response_text, text_data) +        self.assertEqual(error.response_json, {}) + +    def test_response_code_error_string_representation_with_nonempty_response_text(self): +        """Test the string representation of `ResponseCodeError` initialized with text.""" +        text_data = 'Lemon will eat your soul' +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_text=text_data +        ) +        self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") + + +class LoggingHandlerTests(LoggingTestCase): +    """Tests the bot's API Log Handler.""" + +    @classmethod +    def setUpClass(cls): +        cls.debug_log_record = logging.LogRecord( +            name='my.logger', level=logging.DEBUG, +            pathname='my/logger.py', lineno=666, +            msg="Lemon wins", args=(), +            exc_info=None +        ) + +        cls.trace_log_record = logging.LogRecord( +            name='my.logger', level=logging.TRACE, +            pathname='my/logger.py', lineno=666, +            msg="This will not be logged", args=(), +            exc_info=None +        ) + +    def setUp(self): +        self.log_handler = api.APILoggingHandler(None) + +    def test_emit_appends_to_queue_with_stopped_event_loop(self): +        """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" +        with patch("bot.api.APILoggingHandler.ship_off") as ship_off: +            # Patch `ship_off` to ease testing against the return value of this coroutine. +            ship_off.return_value = 42 +            self.log_handler.emit(self.debug_log_record) + +        self.assertListEqual(self.log_handler.queue, [42]) + +    def test_emit_ignores_less_than_debug(self): +        """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" +        self.log_handler.emit(self.trace_log_record) +        self.assertListEqual(self.log_handler.queue, []) + +    def test_schedule_queued_tasks_for_empty_queue(self): +        """`APILoggingHandler` should not schedule anything when the queue is empty.""" +        with self.assertNotLogs(level=logging.DEBUG): +            self.log_handler.schedule_queued_tasks() + +    def test_schedule_queued_tasks_for_nonempty_queue(self): +        """`APILoggingHandler` should schedule logs when the queue is not empty.""" +        with self.assertLogs(level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: +            self.log_handler.queue = [555] +            self.log_handler.schedule_queued_tasks() +            self.assertListEqual(self.log_handler.queue, []) +            create_task.assert_called_once_with(555) + +            [record] = logs.records +            self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") +            self.assertEqual(record.levelno, logging.DEBUG) +            self.assertEqual(record.name, 'bot.api') +            self.assertIn('via_handler', record.__dict__) diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py new file mode 100644 index 000000000..dae7c066c --- /dev/null +++ b/tests/bot/test_constants.py @@ -0,0 +1,26 @@ +import inspect +import unittest + +from bot import constants + + +class ConstantsTests(unittest.TestCase): +    """Tests for our constants.""" + +    def test_section_configuration_matches_type_specification(self): +        """The section annotations should match the actual types of the sections.""" + +        sections = ( +            cls +            for (name, cls) in inspect.getmembers(constants) +            if hasattr(cls, 'section') and isinstance(cls, type) +        ) +        for section in sections: +            for name, annotation in section.__annotations__.items(): +                with self.subTest(section=section, name=name, annotation=annotation): +                    value = getattr(section, name) + +                    if getattr(annotation, '_name', None) in ('Dict', 'List'): +                        self.skipTest("Cannot validate containers yet.") + +                    self.assertIsInstance(value, annotation) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py new file mode 100644 index 000000000..b2b78d9dd --- /dev/null +++ b/tests/bot/test_converters.py @@ -0,0 +1,273 @@ +import asyncio +import datetime +import unittest +from unittest.mock import MagicMock, patch + +from dateutil.relativedelta import relativedelta +from discord.ext.commands import BadArgument + +from bot.converters import ( +    Duration, +    ISODateTime, +    TagContentConverter, +    TagNameConverter, +    ValidPythonIdentifier, +) + + +class ConverterTests(unittest.TestCase): +    """Tests our custom argument converters.""" + +    @classmethod +    def setUpClass(cls): +        cls.context = MagicMock +        cls.context.author = 'bob' + +        cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') + +    def test_tag_content_converter_for_valid(self): +        """TagContentConverter should return correct values for valid input.""" +        test_values = ( +            ('hello', 'hello'), +            ('  h ello  ', 'h ello'), +        ) + +        for content, expected_conversion in test_values: +            with self.subTest(content=content, expected_conversion=expected_conversion): +                conversion = asyncio.run(TagContentConverter.convert(self.context, content)) +                self.assertEqual(conversion, expected_conversion) + +    def test_tag_content_converter_for_invalid(self): +        """TagContentConverter should raise the proper exception for invalid input.""" +        test_values = ( +            ('', "Tag contents should not be empty, or filled with whitespace."), +            ('   ', "Tag contents should not be empty, or filled with whitespace."), +        ) + +        for value, exception_message in test_values: +            with self.subTest(tag_content=value, exception_message=exception_message): +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(TagContentConverter.convert(self.context, value)) + +    def test_tag_name_converter_for_valid(self): +        """TagNameConverter should return the correct values for valid tag names.""" +        test_values = ( +            ('tracebacks', 'tracebacks'), +            ('Tracebacks', 'tracebacks'), +            ('  Tracebacks  ', 'tracebacks'), +        ) + +        for name, expected_conversion in test_values: +            with self.subTest(name=name, expected_conversion=expected_conversion): +                conversion = asyncio.run(TagNameConverter.convert(self.context, name)) +                self.assertEqual(conversion, expected_conversion) + +    def test_tag_name_converter_for_invalid(self): +        """TagNameConverter should raise the correct exception for invalid tag names.""" +        test_values = ( +            ('👋', "Don't be ridiculous, you can't use that character!"), +            ('', "Tag names should not be empty, or filled with whitespace."), +            ('  ', "Tag names should not be empty, or filled with whitespace."), +            ('42', "Tag names can't be numbers."), +            ('x' * 128, "Are you insane? That's way too long!"), +        ) + +        for invalid_name, exception_message in test_values: +            with self.subTest(invalid_name=invalid_name, exception_message=exception_message): +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(TagNameConverter.convert(self.context, invalid_name)) + +    def test_valid_python_identifier_for_valid(self): +        """ValidPythonIdentifier returns valid identifiers unchanged.""" +        test_values = ('foo', 'lemon') + +        for name in test_values: +            with self.subTest(identifier=name): +                conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name)) +                self.assertEqual(name, conversion) + +    def test_valid_python_identifier_for_invalid(self): +        """ValidPythonIdentifier raises the proper exception for invalid identifiers.""" +        test_values = ('nested.stuff', '#####') + +        for name in test_values: +            with self.subTest(identifier=name): +                exception_message = f'`{name}` is not a valid Python identifier' +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(ValidPythonIdentifier.convert(self.context, name)) + +    def test_duration_converter_for_valid(self): +        """Duration returns the correct `datetime` for valid duration strings.""" +        test_values = ( +            # Simple duration strings +            ('1Y', {"years": 1}), +            ('1y', {"years": 1}), +            ('1year', {"years": 1}), +            ('1years', {"years": 1}), +            ('1m', {"months": 1}), +            ('1month', {"months": 1}), +            ('1months', {"months": 1}), +            ('1w', {"weeks": 1}), +            ('1W', {"weeks": 1}), +            ('1week', {"weeks": 1}), +            ('1weeks', {"weeks": 1}), +            ('1d', {"days": 1}), +            ('1D', {"days": 1}), +            ('1day', {"days": 1}), +            ('1days', {"days": 1}), +            ('1h', {"hours": 1}), +            ('1H', {"hours": 1}), +            ('1hour', {"hours": 1}), +            ('1hours', {"hours": 1}), +            ('1M', {"minutes": 1}), +            ('1minute', {"minutes": 1}), +            ('1minutes', {"minutes": 1}), +            ('1s', {"seconds": 1}), +            ('1S', {"seconds": 1}), +            ('1second', {"seconds": 1}), +            ('1seconds', {"seconds": 1}), + +            # Complex duration strings +            ( +                '1y1m1w1d1H1M1S', +                { +                    "years": 1, +                    "months": 1, +                    "weeks": 1, +                    "days": 1, +                    "hours": 1, +                    "minutes": 1, +                    "seconds": 1 +                } +            ), +            ('5y100S', {"years": 5, "seconds": 100}), +            ('2w28H', {"weeks": 2, "hours": 28}), + +            # Duration strings with spaces +            ('1 year 2 months', {"years": 1, "months": 2}), +            ('1d 2H', {"days": 1, "hours": 2}), +            ('1 week2 days', {"weeks": 1, "days": 2}), +        ) + +        converter = Duration() + +        for duration, duration_dict in test_values: +            expected_datetime = self.fixed_utc_now + relativedelta(**duration_dict) + +            with patch('bot.converters.datetime') as mock_datetime: +                mock_datetime.utcnow.return_value = self.fixed_utc_now + +                with self.subTest(duration=duration, duration_dict=duration_dict): +                    converted_datetime = asyncio.run(converter.convert(self.context, duration)) +                    self.assertEqual(converted_datetime, expected_datetime) + +    def test_duration_converter_for_invalid(self): +        """Duration raises the right exception for invalid duration strings.""" +        test_values = ( +            # Units in wrong order +            ('1d1w'), +            ('1s1y'), + +            # Duplicated units +            ('1 year 2 years'), +            ('1 M 10 minutes'), + +            # Unknown substrings +            ('1MVes'), +            ('1y3breads'), + +            # Missing amount +            ('ym'), + +            # Incorrect whitespace +            (" 1y"), +            ("1S "), +            ("1y  1m"), + +            # Garbage +            ('Guido van Rossum'), +            ('lemon lemon lemon lemon lemon lemon lemon'), +        ) + +        converter = Duration() + +        for invalid_duration in test_values: +            with self.subTest(invalid_duration=invalid_duration): +                exception_message = f'`{invalid_duration}` is not a valid duration string.' +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(converter.convert(self.context, invalid_duration)) + +    def test_isodatetime_converter_for_valid(self): +        """ISODateTime converter returns correct datetime for valid datetime string.""" +        test_values = ( +            # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` +            ('2019-09-02T02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM` +            ('2019-09-02T03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02T00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM` +            ('2019-09-02T03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02T00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH` +            ('2019-09-02 03:03:05+01', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02T01:03:05-01', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS` +            ('2019-09-02T02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM` +            ('2019-11-12T09:15', datetime.datetime(2019, 11, 12, 9, 15)), +            ('2019-11-12 09:15', datetime.datetime(2019, 11, 12, 9, 15)), + +            # `YYYY-mm-dd` +            ('2019-04-01', datetime.datetime(2019, 4, 1)), + +            # `YYYY-mm` +            ('2019-02-01', datetime.datetime(2019, 2, 1)), + +            # `YYYY` +            ('2025', datetime.datetime(2025, 1, 1)), +        ) + +        converter = ISODateTime() + +        for datetime_string, expected_dt in test_values: +            with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt): +                converted_dt = asyncio.run(converter.convert(self.context, datetime_string)) +                self.assertIsNone(converted_dt.tzinfo) +                self.assertEqual(converted_dt, expected_dt) + +    def test_isodatetime_converter_for_invalid(self): +        """ISODateTime converter raises the correct exception for invalid datetime strings.""" +        test_values = ( +            # Make sure it doesn't interfere with the Duration converter +            ('1Y'), +            ('1d'), +            ('1H'), + +            # Check if it fails when only providing the optional time part +            ('10:10:10'), +            ('10:00'), + +            # Invalid date format +            ('19-01-01'), + +            # Other non-valid strings +            ('fisk the tag master'), +        ) + +        converter = ISODateTime() +        for datetime_string in test_values: +            with self.subTest(datetime_string=datetime_string): +                exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(converter.convert(self.context, datetime_string)) diff --git a/tests/test_pagination.py b/tests/bot/test_pagination.py index 11d6541ae..0a734b505 100644 --- a/tests/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -1,28 +1,35 @@  from unittest import TestCase -import pytest -  from bot import pagination  class LinePaginatorTests(TestCase): +    """Tests functionality of the `LinePaginator`.""" +      def setUp(self): +        """Create a paginator for the test method."""          self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30)      def test_add_line_raises_on_too_long_lines(self): +        """`add_line` should raise a `RuntimeError` for too long lines."""          message = f"Line exceeds maximum page size {self.paginator.max_size - 2}" -        with pytest.raises(RuntimeError, match=message): +        with self.assertRaises(RuntimeError, msg=message):              self.paginator.add_line('x' * self.paginator.max_size)      def test_add_line_works_on_small_lines(self): +        """`add_line` should allow small lines to be added."""          self.paginator.add_line('x' * (self.paginator.max_size - 3))  class ImagePaginatorTests(TestCase): +    """Tests functionality of the `ImagePaginator`.""" +      def setUp(self): +        """Create a paginator for the test method."""          self.paginator = pagination.ImagePaginator()      def test_add_image_appends_image(self): +        """`add_image` appends the image to the image list."""          image = 'lemon'          self.paginator.add_image(image) diff --git a/tests/bot/utils/__init__.py b/tests/bot/utils/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/utils/__init__.py diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py new file mode 100644 index 000000000..19b758336 --- /dev/null +++ b/tests/bot/utils/test_checks.py @@ -0,0 +1,51 @@ +import unittest + +from bot.utils import checks +from tests.helpers import MockContext, MockRole + + +class ChecksTests(unittest.TestCase): +    """Tests the check functions defined in `bot.checks`.""" + +    def setUp(self): +        self.ctx = MockContext() + +    def test_with_role_check_without_guild(self): +        """`with_role_check` returns `False` if `Context.guild` is None.""" +        self.ctx.guild = None +        self.assertFalse(checks.with_role_check(self.ctx)) + +    def test_with_role_check_without_required_roles(self): +        """`with_role_check` returns `False` if `Context.author` lacks the required role.""" +        self.ctx.author.roles = [] +        self.assertFalse(checks.with_role_check(self.ctx)) + +    def test_with_role_check_with_guild_and_required_role(self): +        """`with_role_check` returns `True` if `Context.author` has the required role.""" +        self.ctx.author.roles.append(MockRole(role_id=10)) +        self.assertTrue(checks.with_role_check(self.ctx, 10)) + +    def test_without_role_check_without_guild(self): +        """`without_role_check` should return `False` when `Context.guild` is None.""" +        self.ctx.guild = None +        self.assertFalse(checks.without_role_check(self.ctx)) + +    def test_without_role_check_returns_false_with_unwanted_role(self): +        """`without_role_check` returns `False` if `Context.author` has unwanted role.""" +        role_id = 42 +        self.ctx.author.roles.append(MockRole(role_id=role_id)) +        self.assertFalse(checks.without_role_check(self.ctx, role_id)) + +    def test_without_role_check_returns_true_without_unwanted_role(self): +        """`without_role_check` returns `True` if `Context.author` does not have unwanted role.""" +        role_id = 42 +        self.ctx.author.roles.append(MockRole(role_id=role_id)) +        self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) + +    def test_in_channel_check_for_correct_channel(self): +        self.ctx.channel.id = 42 +        self.assertTrue(checks.in_channel_check(self.ctx, *[42])) + +    def test_in_channel_check_for_incorrect_channel(self): +        self.ctx.channel.id = 42 + 10 +        self.assertFalse(checks.in_channel_check(self.ctx, *[42])) | 
