diff options
author | 2022-06-14 23:02:10 +0200 | |
---|---|---|
committer | 2022-06-15 00:02:08 +0200 | |
commit | 83f8b2e8ec746232891e314eb33dd8f6ceb9c1af (patch) | |
tree | 38c5e0af94359e9cc31be31160242ff99425b13d /docs | |
parent | Add decorator to block duplicate command invocations in a channel (diff) |
Check assignments nested in ifs when searching for symbol definition
Diffstat (limited to 'docs')
-rw-r--r-- | docs/utils.py | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/docs/utils.py b/docs/utils.py index bb8074ba..5be5292b 100644 --- a/docs/utils.py +++ b/docs/utils.py @@ -71,21 +71,11 @@ def linkcode_resolve(repo_link: str, domain: str, info: dict[str, str]) -> typin while isinstance(source.body[0], ast.ClassDef): source = source.body[0] - for ast_obj in source.body: - if isinstance(ast_obj, ast.Assign): - names = [] - for target in ast_obj.targets: - if isinstance(target, ast.Tuple): - names.extend([name.id for name in target.elts]) - else: - names.append(target.id) - - if symbol_name in names: - start, end = ast_obj.lineno, ast_obj.end_lineno - break - else: + pos = _global_assign_pos(source, symbol_name) + if pos is None: raise Exception(f"Could not find symbol `{symbol_name}` in {module.__name__}.") - + else: + start, end = pos _, offset = inspect.getsourcelines(symbol[-2]) if offset != 0: offset -= 1 @@ -107,6 +97,35 @@ def linkcode_resolve(repo_link: str, domain: str, info: dict[str, str]) -> typin return url +class NodeWithBody(typing.Protocol): + """An AST node with the body attribute.""" + body: list[ast.AST] + + +def _global_assign_pos(ast_: NodeWithBody, name: str) -> typing.Union[tuple[int, int], None]: + """ + Find the first instance where the `name` global is defined in `ast_`. + + Top level assignments, and assignments nested in top level ifs are checked. + """ + for ast_obj in ast_.body: + if isinstance(ast_obj, ast.Assign): + names = [] + for target in ast_obj.targets: + if isinstance(target, ast.Tuple): + names.extend([name.id for name in target.elts]) + else: + names.append(target.id) + + if name in names: + return ast_obj.lineno, ast_obj.end_lineno + + elif isinstance(ast_obj, ast.If): + pos_in_if = _global_assign_pos(ast_obj, name) + if pos_in_if is not None: + return pos_in_if + + def cleanup() -> None: """Remove unneeded autogenerated doc files, and clean up others.""" included = __get_included() |