aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar Mark <[email protected]>2020-05-30 10:56:45 -0700
committerGravatar GitHub <[email protected]>2020-05-30 10:56:45 -0700
commit6ba4c90ad8fc42385e178caa8d3fb8f42dc0cd60 (patch)
treee9c099ec814983cf06463c2e41122bbebbde910e /tests
parentScheduler: Move space from f-string of `ctx.send` to `infr_message` (diff)
parentMerge pull request #930 from MrGrote/test_antimalware (diff)
Merge branch 'master' into ban-kick-reason-length
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/cogs/test_antimalware.py159
-rw-r--r--tests/bot/cogs/test_duck_pond.py2
-rw-r--r--tests/bot/cogs/test_information.py3
-rw-r--r--tests/bot/cogs/test_snekbox.py15
-rw-r--r--tests/bot/test_constants.py43
-rw-r--r--tests/bot/test_decorators.py4
-rw-r--r--tests/bot/utils/test_checks.py52
-rw-r--r--tests/bot/utils/test_redis_cache.py273
-rw-r--r--tests/helpers.py30
9 files changed, 550 insertions, 31 deletions
diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py
new file mode 100644
index 000000000..f219fc1ba
--- /dev/null
+++ b/tests/bot/cogs/test_antimalware.py
@@ -0,0 +1,159 @@
+import unittest
+from unittest.mock import AsyncMock, Mock, patch
+
+from discord import NotFound
+
+from bot.cogs import antimalware
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES
+from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole
+
+MODULE = "bot.cogs.antimalware"
+
+
+@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"])
+class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase):
+ """Test the AntiMalware cog."""
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.bot = MockBot()
+ self.cog = antimalware.AntiMalware(self.bot)
+ self.message = MockMessage()
+
+ async def test_message_with_allowed_attachment(self):
+ """Messages with allowed extensions should not be deleted"""
+ attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_message_without_attachment(self):
+ """Messages without attachments should result in no action."""
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_not_called()
+
+ async def test_direct_message_with_attachment(self):
+ """Direct messages should have no action taken."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.guild = None
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_message_with_illegal_extension_gets_deleted(self):
+ """A message containing an illegal extension should send an embed."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_called_once()
+
+ async def test_message_send_by_staff(self):
+ """A message send by a member of staff should be ignored."""
+ staff_role = MockRole(id=STAFF_ROLES[0])
+ self.message.author.roles.append(staff_role)
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ await self.cog.on_message(self.message)
+
+ self.message.delete.assert_not_called()
+
+ async def test_python_file_redirect_embed_description(self):
+ """A message containing a .py file should result in an embed redirecting the user to our paste site"""
+ attachment = MockAttachment(filename="python.py")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+
+ self.assertEqual(embed.description, antimalware.PY_EMBED_DESCRIPTION)
+
+ async def test_txt_file_redirect_embed_description(self):
+ """A message containing a .txt file should result in the correct embed."""
+ attachment = MockAttachment(filename="python.txt")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.TXT_EMBED_DESCRIPTION = Mock()
+ antimalware.TXT_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ cmd_channel = self.bot.get_channel(Channels.bot_commands)
+
+ self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value)
+ antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention)
+
+ async def test_other_disallowed_extention_embed_description(self):
+ """Test the description for a non .py/.txt disallowed extension."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.channel.send = AsyncMock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION = Mock()
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value = "test"
+
+ await self.cog.on_message(self.message)
+ self.message.channel.send.assert_called_once()
+ args, kwargs = self.message.channel.send.call_args
+ embed = kwargs.pop("embed")
+ meta_channel = self.bot.get_channel(Channels.meta)
+
+ self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value)
+ antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with(
+ blocked_extensions_str=".disallowed",
+ meta_channel_mention=meta_channel.mention
+ )
+
+ async def test_removing_deleted_message_logs(self):
+ """Removing an already deleted message logs the correct message"""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+ self.message.delete = AsyncMock(side_effect=NotFound(response=Mock(status=""), message=""))
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+ self.message.delete.assert_called_once()
+
+ async def test_message_with_illegal_attachment_logs(self):
+ """Deleting a message with an illegal attachment should result in a log."""
+ attachment = MockAttachment(filename="python.disallowed")
+ self.message.attachments = [attachment]
+
+ with self.assertLogs(logger=antimalware.log, level="INFO"):
+ await self.cog.on_message(self.message)
+
+ async def test_get_disallowed_extensions(self):
+ """The return value should include all non-whitelisted extensions."""
+ test_values = (
+ ([], []),
+ (AntiMalwareConfig.whitelist, []),
+ ([".first"], []),
+ ([".first", ".disallowed"], [".disallowed"]),
+ ([".disallowed"], [".disallowed"]),
+ ([".disallowed", ".illegal"], [".disallowed", ".illegal"]),
+ )
+
+ for extensions, expected_disallowed_extensions in test_values:
+ with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions):
+ self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions]
+ disallowed_extensions = self.cog.get_disallowed_extensions(self.message)
+ self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions)
+
+
+class AntiMalwareSetupTests(unittest.TestCase):
+ """Tests setup of the `AntiMalware` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ antimalware.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index 7e6bfc748..a8c0107c6 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -45,7 +45,7 @@ class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
self.assertEqual(cog.bot, bot)
self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond)
- bot.loop.create_loop.called_once_with(cog.fetch_webhook())
+ bot.loop.create_task.assert_called_once_with(cog.fetch_webhook())
def test_fetch_webhook_succeeds_without_connectivity_issues(self):
"""The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute."""
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index b5f928dd6..aca6b594f 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -7,10 +7,9 @@ import discord
from bot import constants
from bot.cogs import information
-from bot.decorators import InWhitelistCheckFailure
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
COG_PATH = "bot.cogs.information.Information"
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
index 14299e766..cf9adbee0 100644
--- a/tests/bot/cogs/test_snekbox.py
+++ b/tests/bot/cogs/test_snekbox.py
@@ -21,7 +21,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Post the eval code to the URLs.snekbox_eval_api endpoint."""
resp = MagicMock()
resp.json = AsyncMock(return_value="return")
- self.bot.http_session.post().__aenter__.return_value = resp
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
self.assertEqual(await self.cog.post_eval("import random"), "return")
self.bot.http_session.post.assert_called_with(
@@ -41,7 +44,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
key = "MarkDiamond"
resp = MagicMock()
resp.json = AsyncMock(return_value={"key": key})
- self.bot.http_session.post().__aenter__.return_value = resp
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
self.assertEqual(
await self.cog.upload_output("My awesome output"),
@@ -57,7 +63,10 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
"""Output upload gracefully fallback if the upload fail."""
resp = MagicMock()
resp.json = AsyncMock(side_effect=Exception)
- self.bot.http_session.post().__aenter__.return_value = resp
+
+ context_manager = MagicMock()
+ context_manager.__aenter__.return_value = resp
+ self.bot.http_session.post.return_value = context_manager
log = logging.getLogger("bot.cogs.snekbox")
with self.assertLogs(logger=log, level='ERROR'):
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py
index dae7c066c..f10d6fbe8 100644
--- a/tests/bot/test_constants.py
+++ b/tests/bot/test_constants.py
@@ -1,14 +1,40 @@
import inspect
+import typing
import unittest
from bot import constants
+def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool:
+ """
+ Return True if `value` is an instance of the type represented by `annotation`.
+
+ This doesn't account for things like Unions or checking for homogenous types in collections.
+ """
+ origin = typing.get_origin(annotation)
+
+ # This is done in case a bare e.g. `typing.List` is used.
+ # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`.
+ # `get_origin()` does this normalisation for us.
+ type_ = annotation if origin is None else origin
+
+ return isinstance(value, type_)
+
+
+def is_any_instance(value: typing.Any, types: typing.Collection) -> bool:
+ """Return True if `value` is an instance of any type in `types`."""
+ for type_ in types:
+ if is_annotation_instance(value, type_):
+ return True
+
+ return False
+
+
class ConstantsTests(unittest.TestCase):
"""Tests for our constants."""
def test_section_configuration_matches_type_specification(self):
- """The section annotations should match the actual types of the sections."""
+ """"The section annotations should match the actual types of the sections."""
sections = (
cls
@@ -17,10 +43,15 @@ class ConstantsTests(unittest.TestCase):
)
for section in sections:
for name, annotation in section.__annotations__.items():
- with self.subTest(section=section, name=name, annotation=annotation):
+ with self.subTest(section=section.__name__, name=name, annotation=annotation):
value = getattr(section, name)
+ origin = typing.get_origin(annotation)
+ annotation_args = typing.get_args(annotation)
+ failure_msg = f"{value} is not an instance of {annotation}"
- if getattr(annotation, '_name', None) in ('Dict', 'List'):
- self.skipTest("Cannot validate containers yet.")
-
- self.assertIsInstance(value, annotation)
+ if origin is typing.Union:
+ is_instance = is_any_instance(value, annotation_args)
+ self.assertTrue(is_instance, failure_msg)
+ else:
+ is_instance = is_annotation_instance(value, annotation)
+ self.assertTrue(is_instance, failure_msg)
diff --git a/tests/bot/test_decorators.py b/tests/bot/test_decorators.py
index a17dd3e16..3d450caa0 100644
--- a/tests/bot/test_decorators.py
+++ b/tests/bot/test_decorators.py
@@ -3,10 +3,10 @@ import unittest
import unittest.mock
from bot import constants
-from bot.decorators import InWhitelistCheckFailure, in_whitelist
+from bot.decorators import in_whitelist
+from bot.utils.checks import InWhitelistCheckFailure
from tests import helpers
-
InWhitelistTestCase = collections.namedtuple("WhitelistedContextTestCase", ("kwargs", "ctx", "description"))
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
index 9610771e5..de72e5748 100644
--- a/tests/bot/utils/test_checks.py
+++ b/tests/bot/utils/test_checks.py
@@ -1,6 +1,8 @@
import unittest
+from unittest.mock import MagicMock
from bot.utils import checks
+from bot.utils.checks import InWhitelistCheckFailure
from tests.helpers import MockContext, MockRole
@@ -42,10 +44,48 @@ class ChecksTests(unittest.TestCase):
self.ctx.author.roles.append(MockRole(id=role_id))
self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
- def test_in_channel_check_for_correct_channel(self):
- self.ctx.channel.id = 42
- self.assertTrue(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_correct_channel(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.id` is in the channel list."""
+ channel_id = 3
+ self.ctx.channel.id = channel_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, [channel_id]))
- def test_in_channel_check_for_incorrect_channel(self):
- self.ctx.channel.id = 42 + 10
- self.assertFalse(checks.in_channel_check(self.ctx, *[42]))
+ def test_in_whitelist_check_incorrect_channel(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no channel match."""
+ self.ctx.channel.id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, [4])
+
+ def test_in_whitelist_check_correct_category(self):
+ """`in_whitelist_check` returns `True` if `Context.channel.category_id` is in the category list."""
+ category_id = 3
+ self.ctx.channel.category_id = category_id
+ self.assertTrue(checks.in_whitelist_check(self.ctx, categories=[category_id]))
+
+ def test_in_whitelist_check_incorrect_category(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no category match."""
+ self.ctx.channel.category_id = 3
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, categories=[4])
+
+ def test_in_whitelist_check_correct_role(self):
+ """`in_whitelist_check` returns `True` if any of the `Context.author.roles` are in the roles list."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.assertTrue(checks.in_whitelist_check(self.ctx, roles=[2, 6]))
+
+ def test_in_whitelist_check_incorrect_role(self):
+ """`in_whitelist_check` raises InWhitelistCheckFailure if there's no role match."""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ with self.assertRaises(InWhitelistCheckFailure):
+ checks.in_whitelist_check(self.ctx, roles=[4])
+
+ def test_in_whitelist_check_fail_silently(self):
+ """`in_whitelist_check` test no exception raised if `fail_silently` is `True`"""
+ self.assertFalse(checks.in_whitelist_check(self.ctx, roles=[2, 6], fail_silently=True))
+
+ def test_in_whitelist_check_complex(self):
+ """`in_whitelist_check` test with multiple parameters"""
+ self.ctx.author.roles = (MagicMock(id=1), MagicMock(id=2))
+ self.ctx.channel.category_id = 3
+ self.ctx.channel.id = 5
+ self.assertTrue(checks.in_whitelist_check(self.ctx, channels=[1], categories=[8], roles=[2]))
diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py
new file mode 100644
index 000000000..8c1a40640
--- /dev/null
+++ b/tests/bot/utils/test_redis_cache.py
@@ -0,0 +1,273 @@
+import asyncio
+import unittest
+
+import fakeredis.aioredis
+
+from bot.utils import RedisCache
+from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError
+from tests import helpers
+
+
+class RedisCacheTests(unittest.IsolatedAsyncioTestCase):
+ """Tests the RedisCache class from utils.redis_dict.py."""
+
+ async def asyncSetUp(self): # noqa: N802
+ """Sets up the objects that only have to be initialized once."""
+ self.bot = helpers.MockBot()
+ self.bot.redis_session = await fakeredis.aioredis.create_redis_pool()
+
+ # Okay, so this is necessary so that we can create a clean new
+ # class for every test method, and we want that because it will
+ # ensure we get a fresh loop, which is necessary for test_increment_lock
+ # to be able to pass.
+ class DummyCog:
+ """A dummy cog, for dummies."""
+
+ redis = RedisCache()
+
+ def __init__(self, bot: helpers.MockBot):
+ self.bot = bot
+
+ self.cog = DummyCog(self.bot)
+
+ await self.cog.redis.clear()
+
+ def test_class_attribute_namespace(self):
+ """Test that RedisDict creates a namespace automatically for class attributes."""
+ self.assertEqual(self.cog.redis._namespace, "DummyCog.redis")
+
+ async def test_class_attribute_required(self):
+ """Test that errors are raised when not assigned as a class attribute."""
+ bad_cache = RedisCache()
+ self.assertIs(bad_cache._namespace, None)
+
+ with self.assertRaises(RuntimeError):
+ await bad_cache.set("test", "me_up_deadman")
+
+ def test_namespace_collision(self):
+ """Test that we prevent colliding namespaces."""
+ bob_cache_1 = RedisCache()
+ bob_cache_1._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_1._namespace, "BobRoss")
+
+ bob_cache_2 = RedisCache()
+ bob_cache_2._set_namespace("BobRoss")
+ self.assertEqual(bob_cache_2._namespace, "BobRoss_")
+
+ async def test_set_get_item(self):
+ """Test that users can set and get items from the RedisDict."""
+ test_cases = (
+ ('favorite_fruit', 'melon'),
+ ('favorite_number', 86),
+ ('favorite_fraction', 86.54)
+ )
+
+ # Test that we can get and set different types.
+ for test in test_cases:
+ await self.cog.redis.set(*test)
+ self.assertEqual(await self.cog.redis.get(test[0]), test[1])
+
+ # Test that .get allows a default value
+ self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw")
+
+ async def test_set_item_type(self):
+ """Test that .set rejects keys and values that are not permitted."""
+ fruits = ["lemon", "melon", "apple"]
+
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(fruits, "nice")
+
+ with self.assertRaises(TypeError):
+ await self.cog.redis.set(4.23, "nice")
+
+ async def test_delete_item(self):
+ """Test that .delete allows us to delete stuff from the RedisCache."""
+ # Add an item and verify that it gets added
+ await self.cog.redis.set("internet", "firetruck")
+ self.assertEqual(await self.cog.redis.get("internet"), "firetruck")
+
+ # Delete that item and verify that it gets deleted
+ await self.cog.redis.delete("internet")
+ self.assertIs(await self.cog.redis.get("internet"), None)
+
+ async def test_contains(self):
+ """Test that we can check membership with .contains."""
+ await self.cog.redis.set('favorite_country', "Burkina Faso")
+
+ self.assertIs(await self.cog.redis.contains('favorite_country'), True)
+ self.assertIs(await self.cog.redis.contains('favorite_dentist'), False)
+
+ async def test_items(self):
+ """Test that the RedisDict can be iterated."""
+ # Set up our test cases in the Redis cache
+ test_cases = [
+ ('favorite_turtle', 'Donatello'),
+ ('second_favorite_turtle', 'Leonardo'),
+ ('third_favorite_turtle', 'Raphael'),
+ ]
+ for key, value in test_cases:
+ await self.cog.redis.set(key, value)
+
+ # Consume the AsyncIterator into a regular list, easier to compare that way.
+ redis_items = [item for item in await self.cog.redis.items()]
+
+ # These sequences are probably in the same order now, but probably
+ # isn't good enough for tests. Let's not rely on .hgetall always
+ # returning things in sequence, and just sort both lists to be safe.
+ redis_items = sorted(redis_items)
+ test_cases = sorted(test_cases)
+
+ # If these are equal now, everything works fine.
+ self.assertSequenceEqual(test_cases, redis_items)
+
+ async def test_length(self):
+ """Test that we can get the correct .length from the RedisDict."""
+ await self.cog.redis.set('one', 1)
+ await self.cog.redis.set('two', 2)
+ await self.cog.redis.set('three', 3)
+ self.assertEqual(await self.cog.redis.length(), 3)
+
+ await self.cog.redis.set('four', 4)
+ self.assertEqual(await self.cog.redis.length(), 4)
+
+ async def test_to_dict(self):
+ """Test that the .to_dict method returns a workable dictionary copy."""
+ copy = await self.cog.redis.to_dict()
+ local_copy = {key: value for key, value in await self.cog.redis.items()}
+ self.assertIs(type(copy), dict)
+ self.assertDictEqual(copy, local_copy)
+
+ async def test_clear(self):
+ """Test that the .clear method removes the entire hash."""
+ await self.cog.redis.set('teddy', 'with me')
+ await self.cog.redis.set('in my dreams', 'you have a weird hat')
+ self.assertEqual(await self.cog.redis.length(), 2)
+
+ await self.cog.redis.clear()
+ self.assertEqual(await self.cog.redis.length(), 0)
+
+ async def test_pop(self):
+ """Test that we can .pop an item from the RedisDict."""
+ await self.cog.redis.set('john', 'was afraid')
+
+ self.assertEqual(await self.cog.redis.pop('john'), 'was afraid')
+ self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck')
+ self.assertEqual(await self.cog.redis.length(), 0)
+
+ async def test_update(self):
+ """Test that we can .update the RedisDict with multiple items."""
+ await self.cog.redis.set("reckfried", "lona")
+ await self.cog.redis.set("bel air", "prince")
+ await self.cog.redis.update({
+ "reckfried": "jona",
+ "mega": "hungry, though",
+ })
+
+ result = {
+ "reckfried": "jona",
+ "bel air": "prince",
+ "mega": "hungry, though",
+ }
+ self.assertDictEqual(await self.cog.redis.to_dict(), result)
+
+ def test_typestring_conversion(self):
+ """Test the typestring-related helper functions."""
+ conversion_tests = (
+ (12, "i|12"),
+ (12.4, "f|12.4"),
+ ("cowabunga", "s|cowabunga"),
+ )
+
+ # Test conversion to typestring
+ for _input, expected in conversion_tests:
+ self.assertEqual(self.cog.redis._value_to_typestring(_input), expected)
+
+ # Test conversion from typestrings
+ for _input, expected in conversion_tests:
+ self.assertEqual(self.cog.redis._value_from_typestring(expected), _input)
+
+ # Test that exceptions are raised on invalid input
+ with self.assertRaises(TypeError):
+ self.cog.redis._value_to_typestring(["internet"])
+ self.cog.redis._value_from_typestring("o|firedog")
+
+ async def test_increment_decrement(self):
+ """Test .increment and .decrement methods."""
+ await self.cog.redis.set("entropic", 5)
+ await self.cog.redis.set("disentropic", 12.5)
+
+ # Test default increment
+ await self.cog.redis.increment("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 6)
+
+ # Test default decrement
+ await self.cog.redis.decrement("entropic")
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
+
+ # Test float increment with float
+ await self.cog.redis.increment("disentropic", 2.0)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 14.5)
+
+ # Test float increment with int
+ await self.cog.redis.increment("disentropic", 2)
+ self.assertEqual(await self.cog.redis.get("disentropic"), 16.5)
+
+ # Test negative increments, because why not.
+ await self.cog.redis.increment("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 0)
+
+ # Negative decrements? Sure.
+ await self.cog.redis.decrement("entropic", -5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 5)
+
+ # What about if we use a negative float to decrement an int?
+ # This should convert the type into a float.
+ await self.cog.redis.decrement("entropic", -2.5)
+ self.assertEqual(await self.cog.redis.get("entropic"), 7.5)
+
+ # Let's test that they raise the right errors
+ with self.assertRaises(KeyError):
+ await self.cog.redis.increment("doesn't_exist!")
+
+ await self.cog.redis.set("stringthing", "stringthing")
+ with self.assertRaises(TypeError):
+ await self.cog.redis.increment("stringthing")
+
+ async def test_increment_lock(self):
+ """Test that we can't produce a race condition in .increment."""
+ await self.cog.redis.set("test_key", 0)
+ tasks = []
+
+ # Increment this a lot in different tasks
+ for _ in range(100):
+ task = asyncio.create_task(
+ self.cog.redis.increment("test_key", 1)
+ )
+ tasks.append(task)
+ await asyncio.gather(*tasks)
+
+ # Confirm that the value has been incremented the exact right number of times.
+ value = await self.cog.redis.get("test_key")
+ self.assertEqual(value, 100)
+
+ async def test_exceptions_raised(self):
+ """Testing that the various RuntimeErrors are reachable."""
+ class MyCog:
+ cache = RedisCache()
+
+ def __init__(self):
+ self.other_cache = RedisCache()
+
+ cog = MyCog()
+
+ # Raises "No Bot instance"
+ with self.assertRaises(NoBotInstanceError):
+ await cog.cache.get("john")
+
+ # Raises "RedisCache has no namespace"
+ with self.assertRaises(NoNamespaceError):
+ await cog.other_cache.get("was")
+
+ # Raises "You must access the RedisCache instance through the cog instance"
+ with self.assertRaises(NoParentInstanceError):
+ await MyCog.cache.get("afraid")
diff --git a/tests/helpers.py b/tests/helpers.py
index 91d814b3a..faa839370 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -4,12 +4,15 @@ import collections
import itertools
import logging
import unittest.mock
+from asyncio import AbstractEventLoop
from typing import Iterable, Optional
import discord
+from aiohttp import ClientSession
from discord.ext.commands import Context
from bot.api import APIClient
+from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
@@ -268,10 +271,16 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
spec_set = APIClient
-# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
-bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
-bot_instance.http_session = None
-bot_instance.api_client = None
+def _get_mock_loop() -> unittest.mock.Mock:
+ """Return a mocked asyncio.AbstractEventLoop."""
+ loop = unittest.mock.create_autospec(spec=AbstractEventLoop, spec_set=True)
+
+ # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
+ # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
+ # to prevent "has not been awaited"-warnings.
+ loop.create_task.side_effect = lambda coroutine: coroutine.close()
+
+ return loop
class MockBot(CustomMockMixin, unittest.mock.MagicMock):
@@ -281,17 +290,16 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
- spec_set = bot_instance
- additional_spec_asyncs = ("wait_for",)
+ spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop())
+ additional_spec_asyncs = ("wait_for", "redis_ready")
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
- self.api_client = MockAPIClient()
- # Since calling `create_task` on our MockBot does not actually schedule the coroutine object
- # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
- # to prevent "has not been awaited"-warnings.
- self.loop.create_task.side_effect = lambda coroutine: coroutine.close()
+ self.loop = _get_mock_loop()
+ self.api_client = MockAPIClient(loop=self.loop)
+ self.http_session = unittest.mock.create_autospec(spec=ClientSession, spec_set=True)
+ self.stats = unittest.mock.create_autospec(spec=AsyncStatsClient, spec_set=True)
# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`