aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/cogs/doc.py55
1 files changed, 42 insertions, 13 deletions
diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py
index a1364dd8b..51323e64f 100644
--- a/bot/cogs/doc.py
+++ b/bot/cogs/doc.py
@@ -6,7 +6,7 @@ import textwrap
from collections import OrderedDict
from contextlib import suppress
from types import SimpleNamespace
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import urljoin
import discord
@@ -265,7 +265,7 @@ class Doc(commands.Cog):
return None
if symbol_id == f"module-{symbol}":
- parsed_module = self.parse_module_symbol(symbol_heading, search_html)
+ parsed_module = self.parse_module_symbol(symbol_heading)
if parsed_module is None:
return None
else:
@@ -339,32 +339,29 @@ class Doc(commands.Cog):
return embed
@classmethod
- def parse_module_symbol(cls, heading: PageElement, html: str) -> Optional[Tuple[None, str]]:
+ def parse_module_symbol(cls, heading: PageElement) -> Optional[Tuple[None, str]]:
"""Get page content from the headerlink up to a table or a tag with its class in `SEARCH_END_TAG_ATTRS`."""
start_tag = heading.find("a", attrs={"class": "headerlink"})
if start_tag is None:
return None
- end_tag = start_tag.find_next(cls._match_end_tag)
- if end_tag is None:
+ description = cls.find_all_text_until_tag(start_tag, cls._match_end_tag)
+ if description is None:
return None
- description_start_index = html.find(str(start_tag.parent)) + len(str(start_tag.parent))
- description_end_index = html.find(str(end_tag))
- description = html[description_start_index:description_end_index]
-
return None, description
- @staticmethod
- def parse_symbol(heading: PageElement, html: str) -> Tuple[List[str], str]:
+ @classmethod
+ def parse_symbol(cls, heading: PageElement, html: str) -> Tuple[List[str], str]:
"""
Parse the signatures and description of a symbol.
Collects up to 3 signatures from dt tags and a description from their sibling dd tag.
"""
signatures = []
- description = str(heading.find_next_sibling("dd"))
- description_pos = html.find(description)
+ description_element = heading.find_next_sibling("dd")
+ description_pos = html.find(str(description_element))
+ description = "".join(cls.find_all_text_until_tag(description_element, ("dt",)))
for element in [heading] + heading.find_next_siblings("dt", limit=2):
signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text)
@@ -374,6 +371,38 @@ class Doc(commands.Cog):
return signatures, description
+ @staticmethod
+ def find_all_text_until_tag(
+ start_element: PageElement,
+ tag_filter: Union[Tuple[str], Callable[[Tag], bool]]
+ ) -> Optional[str]:
+ """
+ Get all text from <p> elements until a tag matching `tag_filter` is found, max 1000 elements searched.
+
+ `tag_filter` can be either a tuple of string names to check against,
+ or a filtering callable that's applied to the tags.
+ If no matching end tag is found, None is returned.
+ """
+ text = ""
+ element = start_element
+ for _ in range(1000):
+ if element is None:
+ break
+
+ element = element.find_next()
+ if element.name == "p":
+ text += str(element)
+
+ elif isinstance(tag_filter, tuple):
+ if element.name in tag_filter:
+ break
+ else:
+ if tag_filter(element):
+ break
+ else:
+ return None
+ return text
+
@commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True)
async def docs_group(self, ctx: commands.Context, *, symbol: str) -> None:
"""Lookup documentation for Python symbols."""