diff options
author | 2021-04-27 11:08:37 +0100 | |
---|---|---|
committer | 2021-04-27 11:08:37 +0100 | |
commit | 3c0da0b5718780ab1afd611d53ff520c0b727709 (patch) | |
tree | b50da96c6c01f9d21fceac7e5f67455ea209a7b6 | |
parent | Merge pull request #1426 from python-discord/feat/1365/add-bot-badges-to-user (diff) | |
parent | Merge branch 'main' into master (diff) |
Merge pull request #1028 from dolphingarlic/master
-rw-r--r-- | bot/exts/info/code_snippets.py | 247 |
1 files changed, 247 insertions, 0 deletions
diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py new file mode 100644 index 000000000..c20115830 --- /dev/null +++ b/bot/exts/info/code_snippets.py @@ -0,0 +1,247 @@ +import logging +import re +import textwrap +from urllib.parse import quote_plus + +from aiohttp import ClientResponseError +from discord import Message +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +GITHUB_RE = re.compile( + r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/' + r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_GIST_RE = re.compile( + r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*' + r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?' + r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'} + +GITLAB_RE = re.compile( + r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)' + r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)' +) + +BITBUCKET_RE = re.compile( + r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)' + r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)' +) + + +class CodeSnippets(Cog): + """ + Cog that parses and sends code snippets to Discord. + + Matches each message against a regex and prints the contents of all matched snippets. + """ + + async def _fetch_response(self, url: str, response_format: str, **kwargs) -> str: + """Makes http requests using aiohttp.""" + try: + async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response: + if response_format == 'text': + return await response.text() + elif response_format == 'json': + return await response.json() + except ClientResponseError as error: + log.error(f'Failed to fetch code snippet from {url}. HTTP Status: {error.status}. Message: {str(error)}.') + + def _find_ref(self, path: str, refs: tuple) -> tuple: + """Loops through all branches and tags to find the required ref.""" + # Base case: there is no slash in the branch name + ref, file_path = path.split('/', 1) + # In case there are slashes in the branch name, we loop through all branches and tags + for possible_ref in refs: + if path.startswith(possible_ref['name'] + '/'): + ref = possible_ref['name'] + file_path = path[len(ref) + 1:] + break + return (ref, file_path) + + async def _fetch_github_snippet( + self, + repo: str, + path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitHub repo.""" + # Search the GitHub API for the specified branch + branches = await self._fetch_response( + f'https://api.github.com/repos/{repo}/branches', + 'json', + headers=GITHUB_HEADERS + ) + tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS) + refs = branches + tags + ref, file_path = self._find_ref(path, refs) + + file_contents = await self._fetch_response( + f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', + 'text', + headers=GITHUB_HEADERS, + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + async def _fetch_github_gist_snippet( + self, + gist_id: str, + revision: str, + file_path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitHub gist.""" + gist_json = await self._fetch_response( + f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', + 'json', + headers=GITHUB_HEADERS, + ) + + # Check each file in the gist for the specified file + for gist_file in gist_json['files']: + if file_path == gist_file.lower().replace('.', '-'): + file_contents = await self._fetch_response( + gist_json['files'][gist_file]['raw_url'], + 'text', + ) + return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line) + return '' + + async def _fetch_gitlab_snippet( + self, + repo: str, + path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitLab repo.""" + enc_repo = quote_plus(repo) + + # Searches the GitLab API for the specified branch + branches = await self._fetch_response( + f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches', + 'json' + ) + tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json') + refs = branches + tags + ref, file_path = self._find_ref(path, refs) + enc_ref = quote_plus(ref) + enc_file_path = quote_plus(file_path) + + file_contents = await self._fetch_response( + f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', + 'text', + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + async def _fetch_bitbucket_snippet( + self, + repo: str, + ref: str, + file_path: str, + start_line: int, + end_line: int + ) -> str: + """Fetches a snippet from a BitBucket repo.""" + file_contents = await self._fetch_response( + f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', + 'text', + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str: + """ + Given the entire file contents and target lines, creates a code block. + + First, we split the file contents into a list of lines and then keep and join only the required + ones together. + + We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent + markdown injection. + + Finally, we surround the code with ``` characters. + """ + # Parse start_line and end_line into integers + if end_line is None: + start_line = end_line = int(start_line) + else: + start_line = int(start_line) + end_line = int(end_line) + + split_file_contents = file_contents.splitlines() + + # Make sure that the specified lines are in range + if start_line > end_line: + start_line, end_line = end_line, start_line + if start_line > len(split_file_contents) or end_line < 1: + return '' + start_line = max(1, start_line) + end_line = min(len(split_file_contents), end_line) + + # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection + required = '\n'.join(split_file_contents[start_line - 1:end_line]) + required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') + + # Extracts the code language and checks whether it's a "valid" language + language = file_path.split('/')[-1].split('.')[-1] + trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') + is_valid_language = trimmed_language.isalnum() + if not is_valid_language: + language = '' + + # Adds a label showing the file path to the snippet + if start_line == end_line: + ret = f'`{file_path}` line {start_line}\n' + else: + ret = f'`{file_path}` lines {start_line} to {end_line}\n' + + if len(required) != 0: + return f'{ret}```{language}\n{required}```' + # Returns an empty codeblock if the snippet is empty + return f'{ret}``` ```' + + def __init__(self, bot: Bot): + """Initializes the cog's bot.""" + self.bot = bot + + self.pattern_handlers = [ + (GITHUB_RE, self._fetch_github_snippet), + (GITHUB_GIST_RE, self._fetch_github_gist_snippet), + (GITLAB_RE, self._fetch_gitlab_snippet), + (BITBUCKET_RE, self._fetch_bitbucket_snippet) + ] + + @Cog.listener() + async def on_message(self, message: Message) -> None: + """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" + if not message.author.bot: + all_snippets = [] + + for pattern, handler in self.pattern_handlers: + for match in pattern.finditer(message.content): + snippet = await handler(**match.groupdict()) + all_snippets.append((match.start(), snippet)) + + # Sorts the list of snippets by their match index and joins them into a single message + message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets))) + + if 0 < len(message_to_send) <= 2000 and len(all_snippets) <= 15: + await message.edit(suppress=True) + await wait_for_deletion( + await message.channel.send(message_to_send), + (message.author.id,) + ) + + +def setup(bot: Bot) -> None: + """Load the CodeSnippets cog.""" + bot.add_cog(CodeSnippets(bot)) |