diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bot/cogs/test_antimalware.py | 48 | 
1 files changed, 19 insertions, 29 deletions
| diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index eba439afb..6fb7b399e 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,4 +1,3 @@ -import asyncio  import logging  import unittest  from os.path import splitext @@ -13,7 +12,7 @@ from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole  MODULE = "bot.cogs.antimalware" -class AntiMalwareCogTests(unittest.TestCase): +class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):      """Test the AntiMalware cog."""      def setUp(self): @@ -22,62 +21,56 @@ class AntiMalwareCogTests(unittest.TestCase):          self.cog = antimalware.AntiMalware(self.bot)          self.message = MockMessage() -    def test_message_with_allowed_attachment(self): +    async def test_message_with_allowed_attachment(self):          """Messages with allowed extensions should not be deleted"""          attachment = MockAttachment(filename=f"python.{AntiMalwareConfig.whitelist[0]}")          self.message.attachments = [attachment] -        coroutine = self.cog.on_message(self.message) -        asyncio.run(coroutine) +        await self.cog.on_message(self.message)          self.message.delete.assert_not_called() -    def test_message_without_attachment(self): +    async def test_message_without_attachment(self):          """Messages without attachments should result in no action.""" -        coroutine = self.cog.on_message(self.message) -        self.assertIsNone(asyncio.run(coroutine)) +        self.assertIsNone(await self.cog.on_message(self.message))          self.message.delete.assert_not_called() -    def test_direct_message_with_attachment(self): +    async def test_direct_message_with_attachment(self):          """Direct messages should have no action taken."""          attachment = MockAttachment(filename="python.asdfsff")          self.message.attachments = [attachment]          self.message.guild = None -        coroutine = self.cog.on_message(self.message) -        asyncio.run(coroutine) +        await self.cog.on_message(self.message)          self.message.delete.assert_not_called() -    def test_message_with_illegal_extension_gets_deleted(self): +    async def test_message_with_illegal_extension_gets_deleted(self):          """A message containing an illegal extension should send an embed."""          attachment = MockAttachment(filename="python.asdfsff")          self.message.attachments = [attachment] -        coroutine = self.cog.on_message(self.message) -        asyncio.run(coroutine) +        await self.cog.on_message(self.message)          self.message.delete.assert_called_once() -    def test_message_send_by_staff(self): +    async def test_message_send_by_staff(self):          """A message send by a member of staff should be ignored."""          moderator_role = MockRole(name="Moderator", id=Roles.moderators)          self.message.author.roles.append(moderator_role)          attachment = MockAttachment(filename="python.asdfsff")          self.message.attachments = [attachment] -        coroutine = self.cog.on_message(self.message) -        asyncio.run(coroutine) +        await self.cog.on_message(self.message)          self.message.delete.assert_not_called() -    def test_python_file_redirect_embed(self): +    async def test_python_file_redirect_embed(self):          """A message containing a .python file should result in an embed redirecting the user to our paste site"""          attachment = MockAttachment(filename="python.py")          self.message.attachments = [attachment]          self.message.channel.send = AsyncMock() -        coroutine = self.cog.on_message(self.message) -        asyncio.run(coroutine) +        await self.cog.on_message(self.message)          args, kwargs = self.message.channel.send.call_args          embed = kwargs.pop("embed") @@ -87,13 +80,12 @@ class AntiMalwareCogTests(unittest.TestCase):              f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}"          )) -    def test_txt_file_redirect_embed(self): +    async def test_txt_file_redirect_embed(self):          attachment = MockAttachment(filename="python.txt")          self.message.attachments = [attachment]          self.message.channel.send = AsyncMock() -        coroutine = self.cog.on_message(self.message) -        asyncio.run(coroutine) +        await self.cog.on_message(self.message)          args, kwargs = self.message.channel.send.call_args          embed = kwargs.pop("embed")          cmd_channel = self.bot.get_channel(Channels.bot_commands) @@ -109,34 +101,32 @@ class AntiMalwareCogTests(unittest.TestCase):              f"\n\n{URLs.site_schema}{URLs.site_paste}"          )) -    def test_removing_deleted_message_logs(self): +    async def test_removing_deleted_message_logs(self):          """Removing an already deleted message logs the correct message"""          attachment = MockAttachment(filename="python.asdfsff")          self.message.attachments = [attachment]          self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message="")) -        coroutine = self.cog.on_message(self.message)          logger = logging.getLogger(MODULE)          with self.assertLogs(logger=logger, level="INFO") as logs: -            asyncio.run(coroutine) +            await self.cog.on_message(self.message)          self.assertIn(              f"INFO:{MODULE}:Tried to delete message `{self.message.id}`, but message could not be found.",              logs.output) -    def test_message_with_illegal_attachment_logs(self): +    async def test_message_with_illegal_attachment_logs(self):          """Deleting a message with an illegal attachment should result in a log."""          attachment = MockAttachment(filename="python.asdfsff")          self.message.attachments = [attachment] -        coroutine = self.cog.on_message(self.message)          file_extensions = {splitext(attachment.filename.lower())[1] for attachment in self.message.attachments}          extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist)          blocked_extensions_str = ', '.join(extensions_blocked)          logger = logging.getLogger(MODULE)          with self.assertLogs(logger=logger, level="INFO") as logs: -            asyncio.run(coroutine) +            await self.cog.on_message(self.message)          self.assertEqual(              [                  f"INFO:{MODULE}:" | 
