aboutsummaryrefslogtreecommitdiffstats
path: root/tests/base.py
blob: 029a249ed3ef3d36cbd9d5a8a3e1445f10e3e29d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import logging
import unittest
from contextlib import contextmanager


class _CaptureLogHandler(logging.Handler):
    """
    A logging handler that keeps every record it receives in memory.

    Records are stored as-is (no formatting is applied), in arrival order, in the `records`
    list.
    """

    def __init__(self):
        logging.Handler.__init__(self)
        # Raw `logging.LogRecord` objects, oldest first.
        self.records = []

    def emit(self, record):
        """Store `record` instead of writing it anywhere."""
        self.records += [record]


class LoggingTestCase(unittest.TestCase):
    """TestCase subclass that adds more logging assertion tools."""

    @contextmanager
    def assertNotLogs(self, logger=None, level=None, msg=None):
        """
        Asserts that no logs of `level` and higher were emitted by `logger`.

        Used as a context manager around the code under test::

            with self.assertNotLogs("my.logger", level="WARNING"):
                do_something()

        Parameters:
            logger: a `logging.Logger` instance or a logger name; anything that is not a
                `Logger` is passed to `logging.getLogger`, so `None` selects the root logger.
            level: the minimum level to watch, as a level name or a numeric level. A falsy
                value (`None`, `0`, `""`) selects `logging.INFO`.
            msg: optional custom message added to the `AssertionError` if thrown.

        While the context is active the logger's handlers are replaced by a single capturing
        handler, its level is set to `level` and propagation is disabled; all three are
        restored on exit, even if the wrapped code raises. If the wrapped code raises, that
        exception propagates unchanged and the captured records are not checked.

        If the assertion fails, the recorded log records will be outputted with the
        `AssertionError` message. The context manager does not yield a live `look` into the
        logging records, since we use this context manager when we're testing under the
        assumption that no log records will be emitted.
        """
        if not isinstance(logger, logging.Logger):
            logger = logging.getLogger(logger)

        # Accept a level name ("WARNING") as well as a numeric level; unknown values are
        # passed through untouched so `setLevel` can validate them.
        if level:
            level = logging._nameToLevel.get(level, level)
        else:
            level = logging.INFO

        handler = _CaptureLogHandler()
        # The threshold must be enforced on the handler as well as on the logger: records
        # propagated from a child logger that has its own, lower, explicit level reach this
        # handler without `logger`'s level being consulted.
        handler.setLevel(level)

        old_handlers = logger.handlers[:]
        old_level = logger.level
        old_propagate = logger.propagate

        logger.handlers = [handler]
        logger.setLevel(level)
        logger.propagate = False

        try:
            yield
        finally:
            # Always put the logger back the way we found it.
            logger.handlers = old_handlers
            logger.propagate = old_propagate
            logger.setLevel(old_level)

        if handler.records:
            level_name = logging.getLevelName(level)
            n_logs = len(handler.records)
            base_message = (
                f"{n_logs} logs of {level_name} or higher were triggered on {logger.name}:\n"
            )
            record_message = "\n".join(str(record) for record in handler.records)
            # `_truncateMessage` honours `maxDiff`, so a flood of records stays readable.
            standard_message = self._truncateMessage(base_message, record_message)
            self.fail(self._formatMessage(msg, standard_message))