aboutsummaryrefslogtreecommitdiffstats
path: root/tests/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/base.py')
-rw-r--r--tests/base.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/tests/base.py b/tests/base.py
new file mode 100644
index 000000000..029a249ed
--- /dev/null
+++ b/tests/base.py
@@ -0,0 +1,67 @@
+import logging
+import unittest
+from contextlib import contextmanager
+
+
+class _CaptureLogHandler(logging.Handler):
+ """
+ A logging handler capturing all (raw and formatted) logging output.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.records = []
+
+ def emit(self, record):
+ self.records.append(record)
+
+
+class LoggingTestCase(unittest.TestCase):
+ """TestCase subclass that adds more logging assertion tools."""
+
+ @contextmanager
+ def assertNotLogs(self, logger=None, level=None, msg=None):
+ """
+ Asserts that no logs of `level` and higher were emitted by `logger`.
+
+ You can specify a specific `logger`, the minimum `logging` level we want to watch and a
+ custom `msg` to be added to the `AssertionError` if thrown. If the assertion fails, the
+ recorded log records will be outputted with the `AssertionError` message. The context
+ manager does not yield a live `look` into the logging records, since we use this context
+ manager when we're testing under the assumption that no log records will be emitted.
+ """
+ if not isinstance(logger, logging.Logger):
+ logger = logging.getLogger(logger)
+
+ if level:
+ level = logging._nameToLevel.get(level, level)
+ else:
+ level = logging.INFO
+
+ handler = _CaptureLogHandler()
+ old_handlers = logger.handlers[:]
+ old_level = logger.level
+ old_propagate = logger.propagate
+
+ logger.handlers = [handler]
+ logger.setLevel(level)
+ logger.propagate = False
+
+ try:
+ yield
+ except Exception as exc:
+ raise exc
+ finally:
+ logger.handlers = old_handlers
+ logger.propagate = old_propagate
+ logger.setLevel(old_level)
+
+ if handler.records:
+ level_name = logging.getLevelName(level)
+ n_logs = len(handler.records)
+ base_message = f"{n_logs} logs of {level_name} or higher were triggered on {logger.name}:\n"
+ records = [str(record) for record in handler.records]
+ record_message = "\n".join(records)
+ standard_message = self._truncateMessage(base_message, record_message)
+ msg = self._formatMessage(msg, standard_message)
+ self.fail(msg)