1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
import logging
import unittest
from contextlib import contextmanager
from typing import Dict
import discord
from async_rediscache import RedisSession
from discord.ext import commands
from bot.log import get_logger
from tests import helpers
class _CaptureLogHandler(logging.Handler):
    """
    A logging handler that stores every `LogRecord` it receives in `self.records`.

    Records are kept raw (unformatted); nothing is written to any stream.
    """

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        """Store `record` for later inspection."""
        self.records.append(record)


class LoggingTestsMixin:
    """
    A mixin that defines additional test methods for logging behavior.

    This mixin relies on the `fail`, `_formatMessage` and `_truncateMessage` methods provided by
    `unittest.TestCase` to build the failure message and signal test failure.
    """

    @contextmanager
    def assertNotLogs(self, logger=None, level=None, msg=None):  # noqa: N802
        """
        Assert that no logs of `level` or higher are emitted by `logger` inside the block.

        `logger` may be a `logging.Logger` instance; anything else (a name, or None) is resolved
        with `get_logger`. `level` may be a level name or number and defaults to `logging.INFO`
        when falsy. `msg` is an optional custom message added to the `AssertionError` if the
        assertion fails; the recorded log records are included in that error message.

        Nothing is yielded: this is used under the assumption that no records will be emitted,
        so there is nothing useful to inspect while the block runs.

        While the block runs, the logger's handlers are replaced by a capturing handler and
        propagation is disabled. The original handlers, level and propagation flag are restored
        afterwards, even if the block -- or the setup itself -- raises.
        """
        if not isinstance(logger, logging.Logger):
            logger = get_logger(logger)

        if level:
            # Accept level names such as "WARNING" as well as numeric levels.
            level = logging._nameToLevel.get(level, level)
        else:
            level = logging.INFO

        handler = _CaptureLogHandler()
        old_handlers = logger.handlers[:]
        old_level = logger.level
        old_propagate = logger.propagate

        try:
            # Configure the logger inside `try` so that a rejected level (e.g. an unknown level
            # name making `setLevel` raise) still restores the original configuration below.
            logger.handlers = [handler]
            logger.setLevel(level)
            logger.propagate = False
            yield
        finally:
            logger.handlers = old_handlers
            logger.propagate = old_propagate
            logger.setLevel(old_level)

            # Checked in `finally` so unexpected logs are reported even when the block raises.
            if handler.records:
                level_name = logging.getLevelName(level)
                n_logs = len(handler.records)
                base_message = f"{n_logs} logs of {level_name} or higher were triggered on {logger.name}:\n"
                records = [str(record) for record in handler.records]
                record_message = "\n".join(records)
                standard_message = self._truncateMessage(base_message, record_message)
                msg = self._formatMessage(msg, standard_message)
                self.fail(msg)
class CommandTestCase(unittest.IsolatedAsyncioTestCase):
    """Async test case bundling extra assertions for testing Discord commands."""

    async def assertHasPermissionsCheck(  # noqa: N802
        self,
        cmd: commands.Command,
        permissions: Dict[str, bool],
    ) -> None:
        """
        Assert that `cmd` raises `MissingPermissions` when the author lacks `permissions`.

        `permissions` should mirror what was passed to the command's check decorator. Every
        entry is expected to be reported as missing, so do not mix in permissions that should
        not cause the check to fail.
        """
        # Flip every flag: the decorator lists what is required, so the mocked channel must
        # report the opposite for the author to lack those permissions.
        denied = {name: not granted for name, granted in permissions.items()}

        mock_ctx = helpers.MockContext()
        mock_ctx.channel.permissions_for.return_value = discord.Permissions(**denied)

        with self.assertRaises(commands.MissingPermissions) as raised:
            await cmd.can_run(mock_ctx)

        missing = raised.exception.missing_permissions
        self.assertCountEqual(denied.keys(), missing)
class RedisTestCase(unittest.IsolatedAsyncioTestCase):
    """
    Base class for test cases that need a redis session.

    Before every test method a new session is created with `use_fakeredis=True`, connected and
    flushed; afterwards its client is closed. No assertions are made here, and tests are free
    to mutate the instance however they like.
    """

    session = None

    async def flush(self):
        """Wipe the whole redis database so no state carries over from one test to the next."""
        redis_client = self.session.client
        await redis_client.flushall()

    async def asyncSetUp(self):
        pending = RedisSession(use_fakeredis=True)
        self.session = await pending.connect()
        await self.flush()

    async def asyncTearDown(self):
        # Nothing to close when no session was ever established.
        if not self.session:
            return
        await self.session.client.close()
|