diff options
author | 2020-06-19 14:32:31 +0200 | |
---|---|---|
committer | 2020-06-19 14:38:43 +0200 | |
commit | ed4097629601704f0c65fc40cceb5fd6757d4779 (patch) | |
tree | 5c8f47cd66363265ec34eca8b12d3922f93d0fbf | |
parent | Merge branch 'origin/master' into kwzrd/incidents (diff) |
Incidents tests: add helper for mocking async for-loops
See the docstring. This does not make the ambition to be powerful
enough to be included in `tests.helpers`, and is only intended
for local purposes.
-rw-r--r-- | tests/bot/cogs/moderation/test_incidents.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tests/bot/cogs/moderation/test_incidents.py b/tests/bot/cogs/moderation/test_incidents.py index 6158d5d20..7fa8847ef 100644 --- a/tests/bot/cogs/moderation/test_incidents.py +++ b/tests/bot/cogs/moderation/test_incidents.py @@ -1,6 +1,7 @@ import asyncio import enum import logging +import typing as t import unittest from unittest.mock import AsyncMock, MagicMock, call, patch @@ -20,6 +21,42 @@ from tests.helpers import ( ) +class MockAsyncIterable: + """ + Helper for mocking asynchronous for loops. + + It does not appear that the `unittest` library currently provides anything that would + allow us to simply mock an async iterator, such as `discord.TextChannel.history`. + + We therefore write our own helper to wrap a regular synchronous iterable, and feed + its values via `__anext__` rather than `__next__`. + + This class was written for the purposes of testing the `Incidents` cog - it may not + be generic enough to be placed in the `tests.helpers` module. + """ + + def __init__(self, messages: t.Iterable): + """Take a sync iterable to be wrapped.""" + self.iter_messages = iter(messages) + + def __aiter__(self): + """Return `self` as we provide the `__anext__` method.""" + return self + + async def __anext__(self): + """ + Feed the next item, or raise `StopAsyncIteration`. + + Since we're wrapping a sync iterator, it will communicate that it has been depleted + by raising a `StopIteration`. The `async for` construct does not expect it, and we + therefore need to substitute it for the appropriate exception type. + """ + try: + return next(self.iter_messages) + except StopIteration: + raise StopAsyncIteration + + class MockSignal(enum.Enum): A = "A" B = "B" |