aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar ks129 <[email protected]>2020-06-11 18:49:00 +0300
committerGravatar GitHub <[email protected]>2020-06-11 18:49:00 +0300
commit49ba7b98b424f9f4199f5f57b4237481954fa06c (patch)
treeef17e82e4f018381432828285780aa453eac54e3 /tests
parentMod Utils Tests: Replace `has_active_infraction` with `get_active_infraction` (diff)
parentFix trailing whitespace in Action file (diff)
Merge branch 'master' into mod-utils-tests
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/cogs/moderation/test_infractions.py55
-rw-r--r--tests/bot/cogs/moderation/test_modlog.py29
-rw-r--r--tests/bot/cogs/sync/test_cog.py3
-rw-r--r--tests/bot/cogs/sync/test_users.py2
-rw-r--r--tests/bot/cogs/test_antimalware.py159
-rw-r--r--tests/bot/cogs/test_duck_pond.py2
-rw-r--r--tests/bot/cogs/test_information.py15
-rw-r--r--tests/bot/cogs/test_snekbox.py26
-rw-r--r--tests/bot/test_constants.py43
-rw-r--r--tests/bot/test_converters.py113
-rw-r--r--tests/bot/test_decorators.py4
-rw-r--r--tests/bot/utils/test_checks.py52
-rw-r--r--tests/bot/utils/test_redis_cache.py273
-rw-r--r--tests/helpers.py34
14 files changed, 715 insertions, 95 deletions
diff --git a/tests/bot/cogs/moderation/test_infractions.py b/tests/bot/cogs/moderation/test_infractions.py
new file mode 100644
index 000000000..da4e92ccc
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_infractions.py
@@ -0,0 +1,55 @@
+import textwrap
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from bot.cogs.moderation.infractions import Infractions
+from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole
+
+
+class TruncationTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for ban and kick command reason truncation."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = Infractions(self.bot)
+ self.user = MockMember(id=1234, top_role=MockRole(id=3577, position=10))
+ self.target = MockMember(id=1265, top_role=MockRole(id=9876, position=0))
+ self.guild = MockGuild(id=4567)
+ self.ctx = MockContext(bot=self.bot, author=self.user, guild=self.guild)
+
+ @patch("bot.cogs.moderation.utils.get_active_infraction")
+ @patch("bot.cogs.moderation.utils.post_infraction")
+ async def test_apply_ban_reason_truncation(self, post_infraction_mock, get_active_mock):
+ """Should truncate reason for `ctx.guild.ban`."""
+ get_active_mock.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.cog.apply_infraction = AsyncMock()
+ self.bot.get_cog.return_value = AsyncMock()
+ self.cog.mod_log.ignore = Mock()
+ self.ctx.guild.ban = Mock()
+
+ await self.cog.apply_ban(self.ctx, self.target, "foo bar" * 3000)
+ self.ctx.guild.ban.assert_called_once_with(
+ self.target,
+ reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."),
+ delete_message_days=0
+ )
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar"}, self.target, self.ctx.guild.ban.return_value
+ )
+
+ @patch("bot.cogs.moderation.utils.post_infraction")
+ async def test_apply_kick_reason_truncation(self, post_infraction_mock):
+ """Should truncate reason for `Member.kick`."""
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.cog.apply_infraction = AsyncMock()
+ self.cog.mod_log.ignore = Mock()
+ self.target.kick = Mock()
+
+ await self.cog.apply_kick(self.ctx, self.target, "foo bar" * 3000)
+ self.target.kick.assert_called_once_with(reason=textwrap.shorten("foo bar" * 3000, 512, placeholder="..."))
+ self.cog.apply_infraction.assert_awaited_once_with(
+ self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value
+ )
diff --git a/tests/bot/cogs/moderation/test_modlog.py b/tests/bot/cogs/moderation/test_modlog.py
new file mode 100644
index 000000000..f2809f40a
--- /dev/null
+++ b/tests/bot/cogs/moderation/test_modlog.py
@@ -0,0 +1,29 @@
+import unittest
+
+import discord
+
+from bot.cogs.moderation.modlog import ModLog
+from tests.helpers import MockBot, MockTextChannel
+
+
+class ModLogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for moderation logs."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = ModLog(self.bot)
+ self.channel = MockTextChannel()
+
+ async def test_log_entry_description_truncation(self):
+ """Test that embed description for ModLog entry is truncated."""
+ self.bot.get_channel.return_value = self.channel
+ await self.cog.send_log_message(
+ icon_url="foo",
+ colour=discord.Colour.blue(),
+ title="bar",
+ text="foo bar" * 3000
+ )
+ embed = self.channel.send.call_args[1]["embed"]
+ self.assertEqual(
+ embed.description, ("foo bar" * 3000)[:2045] + "..."
+ )
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
index 81398c61f..14fd909c4 100644
--- a/tests/bot/cogs/sync/test_cog.py
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -247,14 +247,12 @@ class SyncCogListenerTests(SyncCogTestCase):
before_data = {
"name": "old name",
"discriminator": "1234",
- "avatar": "old avatar",
"bot": False,
}
subtests = (
(True, "name", "name", "new name", "new name"),
(True, "discriminator", "discriminator", "8765", 8765),
- (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),
(False, "bot", "bot", True, True),
)
@@ -295,7 +293,6 @@ class SyncCogListenerTests(SyncCogTestCase):
)
data = {
- "avatar_hash": member.avatar,
"discriminator": int(member.discriminator),
"id": member.id,
"in_guild": True,
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index 818883012..002a947ad 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -10,7 +10,6 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("avatar_hash", None)
kwargs.setdefault("roles", (666,))
kwargs.setdefault("in_guild", True)
@@ -32,7 +31,6 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
for member in members:
member = member.copy()
- member["avatar"] = member.pop("avatar_hash")
del member["in_guild"]
mock_member = helpers.MockMember(**member)
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py
new file mode 100644
index 000000000..f219fc1ba
--- /dev/null
+++ b/tests/bot/cogs/test_antimalware.py
@@ -0,0 +1,159 @@
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from discord import NotFound
+
+from bot.cogs import antimalware
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES
+from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
+
+MODULE = "bot.cogs.antimalware"
+
+
+@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])
+class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
+ """Test the AntiMalware cog."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.bot = MockBot()
+ self.cog = antimalware.AntiMalware(self.bot)
+ self.message = MockMessage()
+
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should not be deleted"""
+ attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_message_without_attachment(self):
+ """Messages without attachments should result in no action."""
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_direct_message_with_attachment(self):
+ """Direct messages should have no action taken."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.guild = None
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_message_with_illegal_extension_gets_deleted(self):
+ """A message containing an illegal extension should send an embed."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_called_once()
+
+ async def test_message_send_by_staff(self):
+ """A message send by a member of staff should be ignored."""
+ staff_role = MockRole(id=STAFF_ROLES[0])
+ self.message.author.roles.append(staff_role)
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site"""
+ attachment = MockAttachment(filename="python.py")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+
+ self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
+
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt file should result in the correct embed."""
+ attachment = MockAttachment(filename="python.txt")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.TXT_EMBED_DESCRIPTION = Mock()
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ cmd_channel = self.bot.get_channel(Channels.bot_commands)
+
+ self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
+ antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
+
+ async def test_other_disallowed_extention_embed_description(self):
+ """Test the description for a non .py/.txt disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ meta_channel = self.bot.get_channel(Channels.meta)
+
+ self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
+ blocked_extensions_str=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+
+ async def test_removing_deleted_message_logs(self):
+ """Removing an already deleted message logs the correct message"""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_called_once()
+
+ async def test_message_with_illegal_attachment_logs(self):
+ """Deleting a message with an illegal attachment should result in a log."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (AntiMalwareConfig.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], [".disallowed"]),
+ ([".disallowed"], [".disallowed"]),
+ ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
+ disallowed_extensions = self.cog.get_disallowed_extensions(self.message)
+ self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
+
+
+class AntiMalwareSetupTests(unittest.TestCase):
+ """Tests setup of the `AntiMalware` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ antimalware.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index 7e6bfc748..a8c0107c6 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
self.assertEqual(cog.bot, bot)
self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond)
- bot.loop.create_loop.called_once_with(cog.fetch_webhook())
+ bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())
def test_fetch_webhook_succeeds_without_connectivity_issues(self):
"""The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute."""
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index b5f928dd6..79c0e0ad3 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -7,10 +7,9 @@ import discord
from bot import constants
from bot.cogs import information
-from bot.decorators import InWhitelistCheckFailure
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
COG_PATH = "bot.cogs.information.Information"
@@ -149,14 +148,18 @@ class InformationCogTests(unittest.TestCase):
Voice region: {self.ctx.guild.region}
Features: {', '.join(self.ctx.guild.features)}
- **Counts**
- Members: {self.ctx.guild.member_count:,}
- Roles: {len(self.ctx.guild.roles)}
+ **Channel counts**
Category channels: 1
Text channels: 1
Voice channels: 1
+ Staff channels: 0
+
+ **Member counts**
+ Members: {self.ctx.guild.member_count:,}
+ Staff members: 0
+ Roles: {len(self.ctx.guild.roles)}
- **Members**
+ **Member statuses**
{constants.Emojis.status_online} 2
{constants.Emojis.status_idle} 1
{constants.Emojis.status_dnd} 4
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
index 1dec0ccaf..cf9adbee0 100644
--- a/tests/bot/cogs/test_snekbox.py
+++ b/tests/bot/cogs/test_snekbox.py
@@ -21,7 +21,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Post the eval code to the URLs.snekbox_eval_api endpoint."""
resp = MagicMock()
resp.json = AsyncMock(return_value="return")
- self.bot.http_session.post().__aenter__.return_value = resp
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
self.assertEqual(await self.cog.post_eval("import random"), "return")
self.bot.http_session.post.assert_called_with(
@@ -41,7 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
key = "MarkDiamond"
resp = MagicMock()
resp.json = AsyncMock(return_value={"key": key})
- self.bot.http_session.post().__aenter__.return_value = resp
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
self.assertEqual(
await self.cog.upload_output("My awesome output"),
@@ -57,7 +63,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Output upload gracefully fallback if the upload fail."""
resp = MagicMock()
resp.json = AsyncMock(side_effect=Exception)
- self.bot.http_session.post().__aenter__.return_value = resp
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
log = logging.getLogger("bot.cogs.snekbox")
with self.assertLogs(logger=log, level='ERROR'):
@@ -208,10 +217,9 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_eval_command_call_help(self):
"""Test if the eval command call the help command if no code is provided."""
- ctx = MockContext()
- ctx.invoke = AsyncMock()
+ ctx = MockContext(command="sentinel")
await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
- ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval")
+ ctx.send_help.assert_called_once_with("sentinel")
async def test_send_eval(self):
"""Test the send_eval function."""
@@ -291,7 +299,11 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(actual, expected)
self.bot.wait_for.assert_has_awaits(
(
- call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10),
+ call(
+ 'message_edit',
+ check=partial_mock(snekbox.predicate_eval_message_edit, ctx),
+ timeout=snekbox.REEVAL_TIMEOUT,
+ ),
call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
)
)
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py
index dae7c066c..f10d6fbe8 100644
--- a/tests/bot/test_constants.py
+++ b/tests/bot/test_constants.py
@@ -1,14 +1,40 @@
import inspect
+import typing
import unittest
from bot import constants
+def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool:
+ """
+ Return True if `value` is an instance of the type represented by `annotation`.
+
+ This doesn't account for things like Unions or checking for homogenous types in collections.
+ """
+ origin = typing.get_origin(annotation)
+
+ # This is done in case a bare e.g. `typing.List` is used.
+ # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`.
+ # `get_origin()` does this normalisation for us.
+ type_ = annotation if origin is None else origin
+
+ return isinstance(value, type_)
+
+
+def is_any_instance(value: typing.Any, types: typing.Collection) -> bool:
+ """Return True if `value` is an instance of any type in `types`."""
+ for type_ in types:
+ if is_annotation_instance(value, type_):
+ return True
+
+ return False
+
+
class ConstantsTests(unittest.TestCase):
"""Tests for our constants."""
def test_section_configuration_matches_type_specification(self):
- """The section annotations should match the actual types of the sections."""
+ """"The section annotations should match the actual types of the sections."""
sections = (
cls
@@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):
)
for section in sections:
for name, annotation in section.__annotations__.items():
- with self.subTest(section=section, name=name, annotation=annotation):
+ with self.subTest(section=section.__name__, name=name, annotation=annotation):
value = getattr(section, name)
+ origin = typing.get_origin(annotation)
+ annotation_args = typing.get_args(annotation)
+ failure_msg = f"{value} is not an instance of {annotation}"
- if getattr(annotation, '_name', None) in ('Dict', 'List'):
- self.skipTest("Cannot validate containers yet.")
-
- self.assertIsInstance(value, annotation)
+ if origin is typing.Union:
+ is_instance = is_any_instance(value, annotation_args)
+ self.assertTrue(is_instance, failure_msg)
+ else:
+ is_instance = is_annotation_instance(value, annotation)
+ self.assertTrue(is_instance, failure_msg)
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index ca8cb6825..c42111f3f 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -1,5 +1,5 @@
-import asyncio
import datetime
+import re
import unittest
from unittest.mock import MagicMock, patch
@@ -16,7 +16,7 @@ from bot.converters import (
)
-class ConverterTests(unittest.TestCase):
+class ConverterTests(unittest.IsolatedAsyncioTestCase):
"""Tests our custom argument converters."""
@classmethod
@@ -26,7 +26,7 @@ class ConverterTests(unittest.TestCase):
cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
- def test_tag_content_converter_for_valid(self):
+ async def test_tag_content_converter_for_valid(self):
"""TagContentConverter should return correct values for valid input."""
test_values = (
('hello', 'hello'),
@@ -35,10 +35,10 @@ class ConverterTests(unittest.TestCase):
for content, expected_conversion in test_values:
with self.subTest(content=content, expected_conversion=expected_conversion):
- conversion = asyncio.run(TagContentConverter.convert(self.context, content))
+ conversion = await TagContentConverter.convert(self.context, content)
self.assertEqual(conversion, expected_conversion)
- def test_tag_content_converter_for_invalid(self):
+ async def test_tag_content_converter_for_invalid(self):
"""TagContentConverter should raise the proper exception for invalid input."""
test_values = (
('', "Tag contents should not be empty, or filled with whitespace."),
@@ -47,10 +47,10 @@ class ConverterTests(unittest.TestCase):
for value, exception_message in test_values:
with self.subTest(tag_content=value, exception_message=exception_message):
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(TagContentConverter.convert(self.context, value))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await TagContentConverter.convert(self.context, value)
- def test_tag_name_converter_for_valid(self):
+ async def test_tag_name_converter_for_valid(self):
"""TagNameConverter should return the correct values for valid tag names."""
test_values = (
('tracebacks', 'tracebacks'),
@@ -60,10 +60,10 @@ class ConverterTests(unittest.TestCase):
for name, expected_conversion in test_values:
with self.subTest(name=name, expected_conversion=expected_conversion):
- conversion = asyncio.run(TagNameConverter.convert(self.context, name))
+ conversion = await TagNameConverter.convert(self.context, name)
self.assertEqual(conversion, expected_conversion)
- def test_tag_name_converter_for_invalid(self):
+ async def test_tag_name_converter_for_invalid(self):
"""TagNameConverter should raise the correct exception for invalid tag names."""
test_values = (
('👋', "Don't be ridiculous, you can't use that character!"),
@@ -75,29 +75,29 @@ class ConverterTests(unittest.TestCase):
for invalid_name, exception_message in test_values:
with self.subTest(invalid_name=invalid_name, exception_message=exception_message):
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(TagNameConverter.convert(self.context, invalid_name))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await TagNameConverter.convert(self.context, invalid_name)
- def test_valid_python_identifier_for_valid(self):
+ async def test_valid_python_identifier_for_valid(self):
"""ValidPythonIdentifier returns valid identifiers unchanged."""
test_values = ('foo', 'lemon')
for name in test_values:
with self.subTest(identifier=name):
- conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ conversion = await ValidPythonIdentifier.convert(self.context, name)
self.assertEqual(name, conversion)
- def test_valid_python_identifier_for_invalid(self):
+ async def test_valid_python_identifier_for_invalid(self):
"""ValidPythonIdentifier raises the proper exception for invalid identifiers."""
test_values = ('nested.stuff', '#####')
for name in test_values:
with self.subTest(identifier=name):
exception_message = f'`{name}` is not a valid Python identifier'
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await ValidPythonIdentifier.convert(self.context, name)
- def test_duration_converter_for_valid(self):
+ async def test_duration_converter_for_valid(self):
"""Duration returns the correct `datetime` for valid duration strings."""
test_values = (
# Simple duration strings
@@ -159,35 +159,35 @@ class ConverterTests(unittest.TestCase):
mock_datetime.utcnow.return_value = self.fixed_utc_now
with self.subTest(duration=duration, duration_dict=duration_dict):
- converted_datetime = asyncio.run(converter.convert(self.context, duration))
+ converted_datetime = await converter.convert(self.context, duration)
self.assertEqual(converted_datetime, expected_datetime)
- def test_duration_converter_for_invalid(self):
+ async def test_duration_converter_for_invalid(self):
"""Duration raises the right exception for invalid duration strings."""
test_values = (
# Units in wrong order
- ('1d1w'),
- ('1s1y'),
+ '1d1w',
+ '1s1y',
# Duplicated units
- ('1 year 2 years'),
- ('1 M 10 minutes'),
+ '1 year 2 years',
+ '1 M 10 minutes',
# Unknown substrings
- ('1MVes'),
- ('1y3breads'),
+ '1MVes',
+ '1y3breads',
# Missing amount
- ('ym'),
+ 'ym',
# Incorrect whitespace
- (" 1y"),
- ("1S "),
- ("1y 1m"),
+ " 1y",
+ "1S ",
+ "1y 1m",
# Garbage
- ('Guido van Rossum'),
- ('lemon lemon lemon lemon lemon lemon lemon'),
+ 'Guido van Rossum',
+ 'lemon lemon lemon lemon lemon lemon lemon',
)
converter = Duration()
@@ -195,10 +195,21 @@ class ConverterTests(unittest.TestCase):
for invalid_duration in test_values:
with self.subTest(invalid_duration=invalid_duration):
exception_message = f'`{invalid_duration}` is not a valid duration string.'
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(converter.convert(self.context, invalid_duration))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, invalid_duration)
- def test_isodatetime_converter_for_valid(self):
+ @patch("bot.converters.datetime")
+ async def test_duration_converter_out_of_range(self, mock_datetime):
+ """Duration converter should raise BadArgument if datetime raises a ValueError."""
+ mock_datetime.__add__.side_effect = ValueError
+ mock_datetime.utcnow.return_value = mock_datetime
+
+ duration = f"{datetime.MAXYEAR}y"
+ exception_message = f"`{duration}` results in a datetime outside the supported range."
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await Duration().convert(self.context, duration)
+
+ async def test_isodatetime_converter_for_valid(self):
"""ISODateTime converter returns correct datetime for valid datetime string."""
test_values = (
# `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
@@ -243,37 +254,37 @@ class ConverterTests(unittest.TestCase):
for datetime_string, expected_dt in test_values:
with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt):
- converted_dt = asyncio.run(converter.convert(self.context, datetime_string))
+ converted_dt = await converter.convert(self.context, datetime_string)
self.assertIsNone(converted_dt.tzinfo)
self.assertEqual(converted_dt, expected_dt)
- def test_isodatetime_converter_for_invalid(self):
+ async def test_isodatetime_converter_for_invalid(self):
"""ISODateTime converter raises the correct exception for invalid datetime strings."""
test_values = (
# Make sure it doesn't interfere with the Duration converter
- ('1Y'),
- ('1d'),
- ('1H'),
+ '1Y',
+ '1d',
+ '1H',
# Check if it fails when only providing the optional time part
- ('10:10:10'),
- ('10:00'),
+ '10:10:10',
+ '10:00',
# Invalid date format
- ('19-01-01'),
+ '19-01-01',
# Other non-valid strings
- ('fisk the tag master'),
+ 'fisk the tag master',
)
converter = ISODateTime()
for datetime_string in test_values:
with self.subTest(datetime_string=datetime_string):
exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string"
- with self.assertRaises(BadArgument, msg=exception_message):
- asyncio.run(converter.convert(self.context, datetime_string))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, datetime_string)
- def test_hush_duration_converter_for_valid(self):
+ async def test_hush_duration_converter_for_valid(self):
"""HushDurationConverter returns correct value for minutes duration or `"forever"` strings."""
test_values = (
("0", 0),
@@ -286,10 +297,10 @@ class ConverterTests(unittest.TestCase):
converter = HushDurationConverter()
for minutes_string, expected_minutes in test_values:
with self.subTest(minutes_string=minutes_string, expected_minutes=expected_minutes):
- converted = asyncio.run(converter.convert(self.context, minutes_string))
+ converted = await converter.convert(self.context, minutes_string)
self.assertEqual(expected_minutes, converted)
- def test_hush_duration_converter_for_invalid(self):
+ async def test_hush_duration_converter_for_invalid(self):
"""HushDurationConverter raises correct exception for invalid minutes duration strings."""
test_values = (
("16", "Duration must be at most 15 minutes."),
@@ -299,5 +310,5 @@ class ConverterTests(unittest.TestCase):
converter = HushDurationConverter()
for invalid_minutes_string, exception_message in test_values:
with self.subTest(invalid_minutes_string=invalid_minutes_string, exception_message=exception_message):
- with self.assertRaisesRegex(BadArgument, exception_message):
- asyncio.run(converter.convert(self.context, invalid_minutes_string))
+ with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
+ await converter.convert(self.context, invalid_minutes_string)
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
index a17dd3e16..3d450caa0 100644
--- a/tests/bot/test_decorators.py
+++ b/tests/bot/test_decorators.py
@@ -3,10 +3,10 @@ import unittest
import unittest.mock
from bot import constants
-from bot.decorators import InWhitelistCheckFailure, in_whitelist
+from bot.decorators import in_whitelist
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description"))
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index 9610771e5..de72e5748 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -1,6 +1,8 @@
import unittest
+from unittest.mock import MagicMock
from bot.utils import checks
+from bot.utils.checks import InWhitelistCheckFailure
from tests.helpers import MockContext, MockRole
@@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):
self.ctx.author.roles.append(MockRole(id=role_id))
self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
- def test_in_channel_check_for_correct_channel(self):
- self.ctx.channel.id = 42
- self.assertTrue(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_correct_channel(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list."""
+ channel_id = 3
+ self.ctx.channel.id = channel_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id]))
- def test_in_channel_check_for_incorrect_channel(self):
- self.ctx.channel.id = 42 + 10
- self.assertFalse(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_incorrect_channel(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match."""
+ self.ctx.channel.id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, [4])
+
+ def test_in_whitelist_check_correct_category(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list."""
+ category_id = 3
+ self.ctx.channel.category_id = category_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id]))
+
+ def test_in_whitelist_check_incorrect_category(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match."""
+ self.ctx.channel.category_id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, categories=[4])
+
+ def test_in_whitelist_check_correct_role(self):
+ """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6]))
+
+ def test_in_whitelist_check_incorrect_role(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, roles=[4])
+
+ def test_in_whitelist_check_fail_silently(self):
+ """`in_whitelist_check` test no exception raised if `fail_silently` is `True`"""
+ self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True))
+
+ def test_in_whitelist_check_complex(self):
+ """`in_whitelist_check` test with multiple parameters"""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.ctx.channel.category_id = 3
+ self.ctx.channel.id = 5
+ self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2]))
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
new file mode 100644
index 000000000..8c1a40640
--- /dev/null
+++ b/tests/bot/utils/test_redis_cache.py
@@ -0,0 +1,273 @@
+import asyncio
+import unittest
+
+import fakeredis.aioredis
+
+from bot.utils import RedisCache
+from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError
+from tests import helpers
+
+
+class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
+ """Tests the RedisCache class from utils.redis_dict.py."""
+
+ async def asyncSetUp(self): # noqa: N802
+ """Sets up the objects that only have to be initialized once."""
+ self.bot = helpers.MockBot()
+ self.bot.redis_session = await fakeredis.aioredis.create_redis_pool()
+
+ # Okay, so this is necessary so that we can create a clean new
+ # class for every test method, and we want that because it will
+ # ensure we get a fresh loop, which is necessary for test_increment_lock
+ # to be able to pass.
+ class DummyCog:
+ """A dummy cog, for dummies."""
+
+ redis = RedisCache()
+
+ def __init__(self, bot: helpers.MockBot):
+ self.bot = bot
+
+ self.cog = DummyCog(self.bot)
+
+ await self.cog.redis.clear()
+
+ def test_class_attribute_namespace(self):
+ """Test that RedisDict creates a namespace automatically for class attributes."""
+ self.assertEqual(self.cog.redis._namespace, "DummyCog.redis")
+
+ async def test_class_attribute_required(self):
+ """Test that errors are raised when not assigned as a class attribute."""
+ bad_cache = RedisCache()
+ self.assertIs(bad_cache._namespace, None)
+
+ with self.assertRaises(RuntimeError):
+ await bad_cache.set("test", "me_up_deadman")
+
+ def test_namespace_collision(self):
+ """Test that we prevent colliding namespaces."""
+ bob_cache_1 = RedisCache()
+ bob_cache_1._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_1._namespace, "BobRoss")
+
+ bob_cache_2 = RedisCache()
+ bob_cache_2._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_2._namespace, "BobRoss_")
+
+ async def test_set_get_item(self):
+ """Test that users can set and get items from the RedisDict."""
+ test_cases = (
+ ('favorite_fruit', 'melon'),
+ ('favorite_number', 86),
+ ('favorite_fraction', 86.54)
+ )
+
+ # Test that we can get and set different types.
+ for test in test_cases:
+ await self.cog.redis.set(*test)
+ self.assertEqual(await self.cog.redis.get(test[0]), test[1])
+
+ # Test that .get allows a default value
+ self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
+
+ async def test_set_item_type(self):
+ """Test that .set rejects keys and values that are not permitted."""
+ fruits = ["lemon", "melon", "apple"]
+
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(fruits, "nice")
+
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(4.23, "nice")
+
+ async def test_delete_item(self):
+ """Test that .delete allows us to delete stuff from the RedisCache."""
+ # Add an item and verify that it gets added
+ await self.cog.redis.set("internet", "firetruck")
+ self.assertEqual(await self.cog.redis.get("internet"), "firetruck")
+
+ # Delete that item and verify that it gets deleted
+ await self.cog.redis.delete("internet")
+ self.assertIs(await self.cog.redis.get("internet"), None)
+
+ async def test_contains(self):
+ """Test that we can check membership with .contains."""
+ await self.cog.redis.set('favorite_country', "Burkina Faso")
+
+ self.assertIs(await self.cog.redis.contains('favorite_country'), True)
+ self.assertIs(await self.cog.redis.contains('favorite_dentist'), False)
+
+ async def test_items(self):
+ """Test that the RedisDict can be iterated."""
+ # Set up our test cases in the Redis cache
+ test_cases = [
+ ('favorite_turtle', 'Donatello'),
+ ('second_favorite_turtle', 'Leonardo'),
+ ('third_favorite_turtle', 'Raphael'),
+ ]
+ for key, value in test_cases:
+ await self.cog.redis.set(key, value)
+
+ # Consume the AsyncIterator into a regular list, easier to compare that way.
+ redis_items = [item for item in await self.cog.redis.items()]
+
+ # These sequences are probably in the same order now, but probably
+ # isn't good enough for tests. Let's not rely on .hgetall always
+ # returning things in sequence, and just sort both lists to be safe.
+ redis_items = sorted(redis_items)
+ test_cases = sorted(test_cases)
+
+ # If these are equal now, everything works fine.
+ self.assertSequenceEqual(test_cases, redis_items)
+
+ async def test_length(self):
+ """Test that we can get the correct .length from the RedisDict."""
+ await self.cog.redis.set('one', 1)
+ await self.cog.redis.set('two', 2)
+ await self.cog.redis.set('three', 3)
+ self.assertEqual(await self.cog.redis.length(), 3)
+
+ await self.cog.redis.set('four', 4)
+ self.assertEqual(await self.cog.redis.length(), 4)
+
+ async def test_to_dict(self):
+ """Test that the .to_dict method returns a workable dictionary copy."""
+ copy = await self.cog.redis.to_dict()
+ local_copy = {key: value for key, value in await self.cog.redis.items()}
+ self.assertIs(type(copy), dict)
+ self.assertDictEqual(copy, local_copy)
+
+ async def test_clear(self):
+ """Test that the .clear method removes the entire hash."""
+ await self.cog.redis.set('teddy', 'with me')
+ await self.cog.redis.set('in my dreams', 'you have a weird hat')
+ self.assertEqual(await self.cog.redis.length(), 2)
+
+ await self.cog.redis.clear()
+ self.assertEqual(await self.cog.redis.length(), 0)
+
+ async def test_pop(self):
+ """Test that we can .pop an item from the RedisDict."""
+ await self.cog.redis.set('john', 'was afraid')
+
+ self.assertEqual(await self.cog.redis.pop('john'), 'was afraid')
+ self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck')
+ self.assertEqual(await self.cog.redis.length(), 0)
+
+ async def test_update(self):
+ """Test that we can .update the RedisDict with multiple items."""
+ await self.cog.redis.set("reckfried", "lona")
+ await self.cog.redis.set("bel air", "prince")
+ await self.cog.redis.update({
+ "reckfried": "jona",
+ "mega": "hungry, though",
+ })
+
+ result = {
+ "reckfried": "jona",
+ "bel air": "prince",
+ "mega": "hungry, though",
+ }
+ self.assertDictEqual(await self.cog.redis.to_dict(), result)
+
+ def test_typestring_conversion(self):
+ """Test the typestring-related helper functions."""
+ conversion_tests = (
+ (12, "i|12"),
+ (12.4, "f|12.4"),
+ ("cowabunga", "s|cowabunga"),
+ )
+
+ # Test conversion to typestring
+ for _input, expected in conversion_tests:
+ self.assertEqual(self.cog.redis._value_to_typestring(_input), expected)
+
+ # Test conversion from typestrings
+ for _input, expected in conversion_tests:
+ self.assertEqual(self.cog.redis._value_from_typestring(expected), _input)
+
+ # Test that exceptions are raised on invalid input
+ with self.assertRaises(TypeError):
+ self.cog.redis._value_to_typestring(["internet"])
+ self.cog.redis._value_from_typestring("o|firedog")
+
+ async def test_increment_decrement(self):
+ """Test .increment and .decrement methods."""
+ await self.cog.redis.set("entropic", 5)
+ await self.cog.redis.set("disentropic", 12.5)
+
+ # Test default increment
+ await self.cog.redis.increment("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 6)
+
+ # Test default decrement
+ await self.cog.redis.decrement("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
+
+ # Test float increment with float
+ await self.cog.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 14.5)
+
+ # Test float increment with int
+ await self.cog.redis.increment("disentropic", 2)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 16.5)
+
+ # Test negative increments, because why not.
+ await self.cog.redis.increment("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 0)
+
+ # Negative decrements? Sure.
+ await self.cog.redis.decrement("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
+
+ # What about if we use a negative float to decrement an int?
+ # This should convert the type into a float.
+ await self.cog.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 7.5)
+
+ # Let's test that they raise the right errors
+ with self.assertRaises(KeyError):
+ await self.cog.redis.increment("doesn't_exist!")
+
+ await self.cog.redis.set("stringthing", "stringthing")
+ with self.assertRaises(TypeError):
+ await self.cog.redis.increment("stringthing")
+
+ async def test_increment_lock(self):
+ """Test that we can't produce a race condition in .increment."""
+ await self.cog.redis.set("test_key", 0)
+ tasks = []
+
+ # Increment this a lot in different tasks
+ for _ in range(100):
+ task = asyncio.create_task(
+ self.cog.redis.increment("test_key", 1)
+ )
+ tasks.append(task)
+ await asyncio.gather(*tasks)
+
+ # Confirm that the value has been incremented the exact right number of times.
+ value = await self.cog.redis.get("test_key")
+ self.assertEqual(value, 100)
+
+ async def test_exceptions_raised(self):
+ """Testing that the various RuntimeErrors are reachable."""
+ class MyCog:
+ cache = RedisCache()
+
+ def __init__(self):
+ self.other_cache = RedisCache()
+
+ cog = MyCog()
+
+ # Raises "No Bot instance"
+ with self.assertRaises(NoBotInstanceError):
+ await cog.cache.get("john")
+
+ # Raises "RedisCache has no namespace"
+ with self.assertRaises(NoNamespaceError):
+ await cog.other_cache.get("was")
+
+ # Raises "You must access the RedisCache instance through the cog instance"
+ with self.assertRaises(NoParentInstanceError):
+ await MyCog.cache.get("afraid")
diff --git a/tests/helpers.py b/tests/helpers.py
index 2b79a6c2a..faa839370 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -4,12 +4,15 @@ import collections
import itertools
import logging
import unittest.mock
+from asyncio import AbstractEventLoop
from typing import Iterable, Optional
import discord
+from aiohttp import ClientSession
from discord.ext.commands import Context
from bot.api import APIClient
+from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
@@ -205,6 +208,10 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
"""Simplified position-based comparisons similar to those of `discord.Role`."""
return self.position < other.position
+ def __ge__(self, other):
+ """Simplified position-based comparisons similar to those of `discord.Role`."""
+ return self.position >= other.position
+
# Create a Member instance to get a realistic Mock of `discord.Member`
member_data = {'user': 'lemon', 'roles': [1]}
@@ -264,10 +271,16 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
spec_set = APIClient
-# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
-bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
-bot_instance.http_session = None
-bot_instance.api_client = None
+def _get_mock_loop() -> unittest.mock.Mock:
+ """Return a mocked asyncio.AbstractEventLoop."""
+ loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True)
+
+ # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
+ # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
+ # to prevent "has not been awaited"-warnings.
+ loop.create_task.side_effect = lambda coroutine: coroutine.close()
+
+ return loop
class MockBot(CustomMockMixin, unittest.mock.MagicMock):
@@ -277,17 +290,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
- spec_set = bot_instance
- additional_spec_asyncs = ("wait_for",)
+ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop())
+ additional_spec_asyncs = ("wait_for", "redis_ready")
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
- self.api_client = MockAPIClient()
- # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
- # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
- # to prevent "has not been awaited"-warnings.
- self.loop.create_task.side_effect = lambda coroutine: coroutine.close()
+ self.loop = _get_mock_loop()
+ self.api_client = MockAPIClient(loop=self.loop)
+ self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True)
+ self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)
# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`