diff options
| -rw-r--r-- | bot/cogs/doc.py | 55 |
1 file changed, 42 insertions, 13 deletions
diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py
index a1364dd8b..51323e64f 100644
--- a/bot/cogs/doc.py
+++ b/bot/cogs/doc.py
@@ -6,7 +6,7 @@ import textwrap
 from collections import OrderedDict
 from contextlib import suppress
 from types import SimpleNamespace
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 from urllib.parse import urljoin
 
 import discord
@@ -265,7 +265,7 @@ class Doc(commands.Cog):
             return None
 
         if symbol_id == f"module-{symbol}":
-            parsed_module = self.parse_module_symbol(symbol_heading, search_html)
+            parsed_module = self.parse_module_symbol(symbol_heading)
             if parsed_module is None:
                 return None
         else:
@@ -339,32 +339,29 @@ class Doc(commands.Cog):
         return embed
 
     @classmethod
-    def parse_module_symbol(cls, heading: PageElement, html: str) -> Optional[Tuple[None, str]]:
+    def parse_module_symbol(cls, heading: PageElement) -> Optional[Tuple[None, str]]:
         """Get page content from the headerlink up to a table or a tag with its class in `SEARCH_END_TAG_ATTRS`."""
         start_tag = heading.find("a", attrs={"class": "headerlink"})
         if start_tag is None:
             return None
 
-        end_tag = start_tag.find_next(cls._match_end_tag)
-        if end_tag is None:
+        description = cls.find_all_text_until_tag(start_tag, cls._match_end_tag)
+        if description is None:
             return None
 
-        description_start_index = html.find(str(start_tag.parent)) + len(str(start_tag.parent))
-        description_end_index = html.find(str(end_tag))
-        description = html[description_start_index:description_end_index]
-
         return None, description
 
-    @staticmethod
-    def parse_symbol(heading: PageElement, html: str) -> Tuple[List[str], str]:
+    @classmethod
+    def parse_symbol(cls, heading: PageElement, html: str) -> Tuple[List[str], str]:
         """
         Parse the signatures and description of a symbol.
 
         Collects up to 3 signatures from dt tags and a description from their sibling dd tag.
         """
         signatures = []
-        description = str(heading.find_next_sibling("dd"))
-        description_pos = html.find(description)
+        description_element = heading.find_next_sibling("dd")
+        description_pos = html.find(str(description_element))
+        description = cls.find_all_text_until_tag(description_element, ("dt",)) or ""
 
         for element in [heading] + heading.find_next_siblings("dt", limit=2):
             signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text)
@@ -374,6 +371,38 @@ class Doc(commands.Cog):
 
         return signatures, description
 
+    @staticmethod
+    def find_all_text_until_tag(
+            start_element: PageElement,
+            tag_filter: Union[Tuple[str, ...], Callable[[Tag], bool]]
+    ) -> Optional[str]:
+        """
+        Get all text from <p> elements until a tag matching `tag_filter` is found, max 1000 elements searched.
+
+        `tag_filter` can be either a tuple of string names to check against,
+        or a filtering callable that's applied to the tags.
+        The text found so far is returned if the document ends first; None if the search limit is reached.
+        """
+        paragraphs = []
+        element = start_element
+        for _ in range(1000):
+            if element is not None:
+                element = element.find_next()
+            if element is None:
+                # `find_next` returns None once the end of the document is reached.
+                break
+
+            if element.name == "p":
+                paragraphs.append(str(element))
+            elif isinstance(tag_filter, tuple):
+                if element.name in tag_filter:
+                    break
+            elif tag_filter(element):
+                break
+        else:
+            return None
+        return "".join(paragraphs)
+
     @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True)
     async def docs_group(self, ctx: commands.Context, *, symbol: str) -> None:
         """Lookup documentation for Python symbols."""