aboutsummaryrefslogtreecommitdiffstats
path: root/tests/_autospec.py
blob: 6f990a5803a7a7e301f825bfbf754737862b1e8b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import contextlib
import functools
import pkgutil
import unittest.mock
from collections.abc import Callable


@functools.wraps(unittest.mock._patch.decoration_helper)
@contextlib.contextmanager
def _decoration_helper(self, patched, args, keywargs):
    """Skips adding patchings as args if their `dont_pass` attribute is True."""
    # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added.
    extra_args = []
    with contextlib.ExitStack() as exit_stack:
        for patching in patched.patchings:
            arg = exit_stack.enter_context(patching)
            if not getattr(patching, "dont_pass", False):
                # Only add the patching as an arg if dont_pass is False.
                if patching.attribute_name is not None:
                    keywargs.update(arg)
                elif patching.new is unittest.mock.DEFAULT:
                    extra_args.append(arg)

        args += tuple(extra_args)
        yield args, keywargs


@functools.wraps(unittest.mock._patch.copy)
def _copy(self):
    """Copy the `dont_pass` attribute along with the standard copy operation."""
    patcher_copy = _copy.original(self)
    patcher_copy.dont_pass = getattr(self, "dont_pass", False)
    return patcher_copy


# Monkey-patch the patcher class :)
_copy.original = unittest.mock._patch.copy
unittest.mock._patch.copy = _copy
unittest.mock._patch.decoration_helper = _decoration_helper


def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable:
    """
    Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.

    If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object.
    """
    # Caller's kwargs should take priority and overwrite the defaults.
    kwargs = dict(spec_set=True, autospec=True)
    kwargs.update(patch_kwargs)

    # Import the target if it's a string.
    # This is to support both object and string targets like patch.multiple.
    if type(target) is str:
        target = pkgutil.resolve_name(target)

    def decorator(func):
        for attribute in attributes:
            patcher = unittest.mock.patch.object(target, attribute, **kwargs)
            if not pass_mocks:
                # A custom attribute to keep track of which patchings should be skipped.
                patcher.dont_pass = True
            func = patcher(func)
        return func
    return decorator