aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar ChrisJL <[email protected]>2022-08-14 19:53:25 +0100
committerGravatar GitHub <[email protected]>2022-08-14 19:53:25 +0100
commit02cff856c9159a8e1acf4fb502405fe30d2c9f68 (patch)
treebc09829a0af5c164546e34565e146f5793ac416e /tests
parentMerge pull request #2240 from python-discord/2238-purge-cmd (diff)
parentrevert bump to markdownify version (diff)
Merge pull request #2229 from python-discord/py3.10-rediscache
Diffstat (limited to 'tests')
-rw-r--r--tests/base.py24
-rw-r--r--tests/bot/exts/backend/test_error_handler.py103
-rw-r--r--tests/bot/exts/moderation/test_incidents.py22
-rw-r--r--tests/bot/exts/moderation/test_silence.py30
4 files changed, 78 insertions, 101 deletions
diff --git a/tests/base.py b/tests/base.py
index 5e304ea9d..4863a1821 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -4,6 +4,7 @@ from contextlib import contextmanager
from typing import Dict
import discord
+from async_rediscache import RedisSession
from discord.ext import commands
from bot.log import get_logger
@@ -104,3 +105,26 @@ class CommandTestCase(unittest.IsolatedAsyncioTestCase):
await cmd.can_run(ctx)
self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions)
+
+
+class RedisTestCase(unittest.IsolatedAsyncioTestCase):
+ """
+ Use this as a base class for any test cases that require a redis session.
+
+ This will prepare a fresh redis instance for each test function, and will
+ not make any assertions on its own. Tests can mutate the instance as they wish.
+ """
+
+ session = None
+
+ async def flush(self):
+ """Flush everything from the redis database to prevent carry-overs between tests."""
+ await self.session.client.flushall()
+
+ async def asyncSetUp(self):
+ self.session = await RedisSession(use_fakeredis=True).connect()
+ await self.flush()
+
+ async def asyncTearDown(self):
+ if self.session:
+ await self.session.client.close()
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 0a58126e7..7562f6aa8 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -5,7 +5,7 @@ from botcore.site_api import ResponseCodeError
from discord.ext.commands import errors
from bot.errors import InvalidInfractedUserError, LockedResourceError
-from bot.exts.backend.error_handler import ErrorHandler, setup
+from bot.exts.backend import error_handler
from bot.exts.info.tags import Tags
from bot.exts.moderation.silence import Silence
from bot.utils.checks import InWhitelistCheckFailure
@@ -18,14 +18,14 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.bot = MockBot()
self.ctx = MockContext(bot=self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
async def test_error_handler_already_handled(self):
"""Should not do anything when error is already handled by local error handler."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
error = errors.CommandError()
error.handled = "foo"
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
self.ctx.send.assert_not_awaited()
async def test_error_handler_command_not_found_error_not_invoked_by_handler(self):
@@ -45,28 +45,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
"called_try_get_tag": True
}
)
- cog = ErrorHandler(self.bot)
- cog.try_silence = AsyncMock()
- cog.try_get_tag = AsyncMock()
- cog.try_run_eval = AsyncMock(return_value=False)
+ self.cog.try_silence = AsyncMock()
+ self.cog.try_get_tag = AsyncMock()
+ self.cog.try_run_eval = AsyncMock(return_value=False)
for case in test_cases:
with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):
self.ctx.reset_mock()
- cog.try_silence.reset_mock(return_value=True)
- cog.try_get_tag.reset_mock()
+ self.cog.try_silence.reset_mock(return_value=True)
+ self.cog.try_get_tag.reset_mock()
- cog.try_silence.return_value = case["try_silence_return"]
+ self.cog.try_silence.return_value = case["try_silence_return"]
self.ctx.channel.id = 1234
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
if case["try_silence_return"]:
- cog.try_get_tag.assert_not_awaited()
- cog.try_silence.assert_awaited_once()
+ self.cog.try_get_tag.assert_not_awaited()
+ self.cog.try_silence.assert_awaited_once()
else:
- cog.try_silence.assert_awaited_once()
- cog.try_get_tag.assert_awaited_once()
+ self.cog.try_silence.assert_awaited_once()
+ self.cog.try_get_tag.assert_awaited_once()
self.ctx.send.assert_not_awaited()
@@ -74,59 +73,54 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
"""Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""
ctx = MockContext(bot=self.bot, invoked_from_error_handler=True)
- cog = ErrorHandler(self.bot)
- cog.try_silence = AsyncMock()
- cog.try_get_tag = AsyncMock()
- cog.try_run_eval = AsyncMock()
+ self.cog.try_silence = AsyncMock()
+ self.cog.try_get_tag = AsyncMock()
+ self.cog.try_run_eval = AsyncMock()
error = errors.CommandNotFound()
- self.assertIsNone(await cog.on_command_error(ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(ctx, error))
- cog.try_silence.assert_not_awaited()
- cog.try_get_tag.assert_not_awaited()
- cog.try_run_eval.assert_not_awaited()
+ self.cog.try_silence.assert_not_awaited()
+ self.cog.try_get_tag.assert_not_awaited()
+ self.cog.try_run_eval.assert_not_awaited()
self.ctx.send.assert_not_awaited()
async def test_error_handler_user_input_error(self):
"""Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
- cog.handle_user_input_error = AsyncMock()
+ self.cog.handle_user_input_error = AsyncMock()
error = errors.UserInputError()
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
- cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
+ self.cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
async def test_error_handler_check_failure(self):
"""Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
- cog.handle_check_failure = AsyncMock()
+ self.cog.handle_check_failure = AsyncMock()
error = errors.CheckFailure()
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
- cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
+ self.cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
async def test_error_handler_command_on_cooldown(self):
"""Should send error with `ctx.send` when error is `CommandOnCooldown`."""
self.ctx.reset_mock()
- cog = ErrorHandler(self.bot)
error = errors.CommandOnCooldown(10, 9, type=None)
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
self.ctx.send.assert_awaited_once_with(error)
async def test_error_handler_command_invoke_error(self):
"""Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
- cog = ErrorHandler(self.bot)
- cog.handle_api_error = AsyncMock()
- cog.handle_unexpected_error = AsyncMock()
+ self.cog.handle_api_error = AsyncMock()
+ self.cog.handle_unexpected_error = AsyncMock()
test_cases = (
{
"args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))),
- "expect_mock_call": cog.handle_api_error
+ "expect_mock_call": self.cog.handle_api_error
},
{
"args": (self.ctx, errors.CommandInvokeError(TypeError)),
- "expect_mock_call": cog.handle_unexpected_error
+ "expect_mock_call": self.cog.handle_unexpected_error
},
{
"args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))),
@@ -141,7 +135,7 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
for case in test_cases:
with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):
self.ctx.send.reset_mock()
- self.assertIsNone(await cog.on_command_error(*case["args"]))
+ self.assertIsNone(await self.cog.on_command_error(*case["args"]))
if case["expect_mock_call"] == "send":
self.ctx.send.assert_awaited_once()
else:
@@ -151,29 +145,27 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
async def test_error_handler_conversion_error(self):
"""Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
- cog = ErrorHandler(self.bot)
- cog.handle_api_error = AsyncMock()
- cog.handle_unexpected_error = AsyncMock()
+ self.cog.handle_api_error = AsyncMock()
+ self.cog.handle_unexpected_error = AsyncMock()
cases = (
{
"error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())),
- "mock_function_to_call": cog.handle_api_error
+ "mock_function_to_call": self.cog.handle_api_error
},
{
"error": errors.ConversionError(AsyncMock(), TypeError),
- "mock_function_to_call": cog.handle_unexpected_error
+ "mock_function_to_call": self.cog.handle_unexpected_error
}
)
for case in cases:
with self.subTest(**case):
- self.assertIsNone(await cog.on_command_error(self.ctx, case["error"]))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, case["error"]))
case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)
async def test_error_handler_two_other_errors(self):
"""Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`."""
- cog = ErrorHandler(self.bot)
- cog.handle_unexpected_error = AsyncMock()
+ self.cog.handle_unexpected_error = AsyncMock()
errs = (
errors.MaxConcurrencyReached(1, MagicMock()),
errors.ExtensionError(name="foo")
@@ -181,16 +173,15 @@ class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
for err in errs:
with self.subTest(error=err):
- cog.handle_unexpected_error.reset_mock()
- self.assertIsNone(await cog.on_command_error(self.ctx, err))
- cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
+ self.cog.handle_unexpected_error.reset_mock()
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, err))
+ self.cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
@patch("bot.exts.backend.error_handler.log")
async def test_error_handler_other_errors(self, log_mock):
"""Should `log.debug` other errors."""
- cog = ErrorHandler(self.bot)
error = errors.DisabledCommand() # Use this just as a other error
- self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.assertIsNone(await self.cog.on_command_error(self.ctx, error))
log_mock.debug.assert_called_once()
@@ -202,7 +193,7 @@ class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
self.silence = Silence(self.bot)
self.bot.get_command.return_value = self.silence.silence
self.ctx = MockContext(bot=self.bot)
- self.cog = ErrorHandler(self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
async def test_try_silence_context_invoked_from_error_handler(self):
"""Should set `Context.invoked_from_error_handler` to `True`."""
@@ -334,7 +325,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
self.bot = MockBot()
self.ctx = MockContext()
self.tag = Tags(self.bot)
- self.cog = ErrorHandler(self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
self.bot.get_command.return_value = self.tag.get_command
async def test_try_get_tag_get_command(self):
@@ -399,7 +390,7 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.bot = MockBot()
self.ctx = MockContext(bot=self.bot)
- self.cog = ErrorHandler(self.bot)
+ self.cog = error_handler.ErrorHandler(self.bot)
async def test_handle_input_error_handler_errors(self):
"""Should handle each error probably."""
@@ -555,5 +546,5 @@ class ErrorHandlerSetupTests(unittest.IsolatedAsyncioTestCase):
async def test_setup(self):
"""Should call `bot.add_cog` with `ErrorHandler`."""
bot = MockBot()
- await setup(bot)
+ await error_handler.setup(bot)
bot.add_cog.assert_awaited_once()
diff --git a/tests/bot/exts/moderation/test_incidents.py b/tests/bot/exts/moderation/test_incidents.py
index cfe0c4b03..97682163f 100644
--- a/tests/bot/exts/moderation/test_incidents.py
+++ b/tests/bot/exts/moderation/test_incidents.py
@@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
import aiohttp
import discord
-from async_rediscache import RedisSession
from bot.constants import Colours
from bot.exts.moderation import incidents
from bot.utils.messages import format_user
+from tests.base import RedisTestCase
from tests.helpers import (
MockAsyncWebhook, MockAttachment, MockBot, MockMember, MockMessage, MockReaction, MockRole, MockTextChannel,
MockUser
@@ -270,7 +270,7 @@ class TestAddSignals(unittest.IsolatedAsyncioTestCase):
self.incident.add_reaction.assert_not_called()
-class TestIncidents(unittest.IsolatedAsyncioTestCase):
+class TestIncidents(RedisTestCase):
"""
Tests for bound methods of the `Incidents` cog.
@@ -279,22 +279,6 @@ class TestIncidents(unittest.IsolatedAsyncioTestCase):
the instance as they wish.
"""
- session = None
-
- async def flush(self):
- """Flush everything from the database to prevent carry-overs between tests."""
- with await self.session.pool as connection:
- await connection.flushall()
-
- async def asyncSetUp(self): # noqa: N802
- self.session = RedisSession(use_fakeredis=True)
- await self.session.connect()
- await self.flush()
-
- async def asyncTearDown(self): # noqa: N802
- if self.session:
- await self.session.close()
-
def setUp(self):
"""
Prepare a fresh `Incidents` instance for each test.
@@ -656,7 +640,7 @@ class TestOnRawReactionAdd(TestIncidents):
emoji="reaction",
)
- async def asyncSetUp(self): # noqa: N802
+ async def asyncSetUp(self):
"""
Prepare an empty task and assign it as `crawl_task`.
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index 65aecad28..98547e2bc 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -6,31 +6,15 @@ from typing import List, Tuple
from unittest import mock
from unittest.mock import AsyncMock, Mock
-from async_rediscache import RedisSession
from discord import PermissionOverwrite
from bot.constants import Channels, Guild, MODERATION_ROLES, Roles
from bot.exts.moderation import silence
+from tests.base import RedisTestCase
from tests.helpers import (
MockBot, MockContext, MockGuild, MockMember, MockRole, MockTextChannel, MockVoiceChannel, autospec
)
-redis_session = None
-redis_loop = asyncio.get_event_loop()
-
-
-def setUpModule(): # noqa: N802
- """Create and connect to the fakeredis session."""
- global redis_session
- redis_session = RedisSession(use_fakeredis=True)
- redis_loop.run_until_complete(redis_session.connect())
-
-
-def tearDownModule(): # noqa: N802
- """Close the fakeredis session."""
- if redis_session:
- redis_loop.run_until_complete(redis_session.close())
-
# Have to subclass it because builtins can't be patched.
class PatchedDatetime(datetime):
@@ -105,7 +89,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
-class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
+class SilenceCogTests(RedisTestCase):
"""Tests for the general functionality of the Silence cog."""
@autospec(silence, "Scheduler", pass_mocks=False)
@@ -245,14 +229,12 @@ class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(member.move_to.call_count, 1 if member == failing_member else 2)
-class SilenceArgumentParserTests(unittest.IsolatedAsyncioTestCase):
+class SilenceArgumentParserTests(RedisTestCase):
"""Tests for the silence argument parser utility function."""
def setUp(self):
self.bot = MockBot()
self.cog = silence.Silence(self.bot)
- self.cog._init_task = asyncio.Future()
- self.cog._init_task.set_result(None)
@autospec(silence.Silence, "send_message", pass_mocks=False)
@autospec(silence.Silence, "_set_silence_overwrites", return_value=False, pass_mocks=False)
@@ -406,7 +388,7 @@ def voice_sync_helper(function):
@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
-class SilenceTests(unittest.IsolatedAsyncioTestCase):
+class SilenceTests(RedisTestCase):
"""Tests for the silence command and its related helper methods."""
@autospec(silence.Silence, "_reschedule", pass_mocks=False)
@@ -414,8 +396,6 @@ class SilenceTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.bot = MockBot(get_channel=lambda _: MockTextChannel())
self.cog = silence.Silence(self.bot)
- self.cog._init_task = asyncio.Future()
- self.cog._init_task.set_result(None)
# Avoid unawaited coroutine warnings.
self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close()
@@ -687,8 +667,6 @@ class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.bot = MockBot(get_channel=lambda _: MockTextChannel())
self.cog = silence.Silence(self.bot)
- self.cog._init_task = asyncio.Future()
- self.cog._init_task.set_result(None)
overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
self.cog.previous_overwrites = overwrites_cache