diff options
| author | 2020-10-27 15:45:09 +0200 | |
|---|---|---|
| committer | 2020-10-27 15:45:09 +0200 | |
| commit | 76afc563ac73f6b8d40194c15e28f42a9fe6be0f (patch) | |
| tree | 2923d19e413d64d5eb63bf52c6067947ac140f3e | |
| parent | Made check for valid language easier to read (diff) | |
Moved global functions into the cog and got rid of unnecessary aiohttp sessions
| -rw-r--r-- | bot/exts/info/code_snippets.py | 307 |
1 files changed, 158 insertions, 149 deletions
diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py index c53c28e8b..12eb692d4 100644 --- a/bot/exts/info/code_snippets.py +++ b/bot/exts/info/code_snippets.py @@ -2,7 +2,6 @@ import re import textwrap from urllib.parse import quote_plus -from aiohttp import ClientSession from discord import Message from discord.ext.commands import Cog @@ -10,150 +9,6 @@ from bot.bot import Bot from bot.utils.messages import wait_for_deletion -async def fetch_response(session: ClientSession, url: str, response_format: str, **kwargs) -> str: - """Makes http requests using aiohttp.""" - async with session.get(url, **kwargs) as response: - if response_format == 'text': - return await response.text() - elif response_format == 'json': - return await response.json() - - -def find_ref(path: str, refs: tuple) -> tuple: - """Loops through all branches and tags to find the required ref.""" - # Base case: there is no slash in the branch name - ref = path.split('/')[0] - file_path = '/'.join(path.split('/')[1:]) - # In case there are slashes in the branch name, we loop through all branches and tags - for possible_ref in refs: - if path.startswith(possible_ref['name'] + '/'): - ref = possible_ref['name'] - file_path = path[len(ref) + 1:] - break - return (ref, file_path) - - -async def fetch_github_snippet(session: ClientSession, repo: str, - path: str, start_line: str, end_line: str) -> str: - """Fetches a snippet from a GitHub repo.""" - headers = {'Accept': 'application/vnd.github.v3.raw'} - - # Search the GitHub API for the specified branch - branches = await fetch_response(session, f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers) - tags = await fetch_response(session, f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers) - refs = branches + tags - ref, file_path = find_ref(path, refs) - - file_contents = await fetch_response( - session, - f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', - 'text', - headers=headers, - ) - return snippet_to_codeblock(file_contents, file_path, start_line, end_line) - - -async def fetch_github_gist_snippet(session: ClientSession, gist_id: str, revision: str, - file_path: str, start_line: str, end_line: str) -> str: - """Fetches a snippet from a GitHub gist.""" - headers = {'Accept': 'application/vnd.github.v3.raw'} - - gist_json = await fetch_response( - session, - f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', - 'json', - headers=headers, - ) - - # Check each file in the gist for the specified file - for gist_file in gist_json['files']: - if file_path == gist_file.lower().replace('.', '-'): - file_contents = await fetch_response( - session, - gist_json['files'][gist_file]['raw_url'], - 'text', - ) - return snippet_to_codeblock(file_contents, gist_file, start_line, end_line) - return '' - - -async def fetch_gitlab_snippet(session: ClientSession, repo: str, - path: str, start_line: str, end_line: str) -> str: - """Fetches a snippet from a GitLab repo.""" - enc_repo = quote_plus(repo) - - # Searches the GitLab API for the specified branch - branches = await fetch_response(session, f'https://api.github.com/repos/{repo}/branches', 'json') - tags = await fetch_response(session, f'https://api.github.com/repos/{repo}/tags', 'json') - refs = branches + tags - ref, file_path = find_ref(path, refs) - enc_ref = quote_plus(ref) - enc_file_path = quote_plus(file_path) - - file_contents = await fetch_response( - session, - f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', - 'text', - ) - return snippet_to_codeblock(file_contents, file_path, start_line, end_line) - - -async def fetch_bitbucket_snippet(session: ClientSession, repo: str, ref: str, - file_path: str, start_line: int, end_line: int) -> str: - """Fetches a snippet from a BitBucket repo.""" - file_contents = await fetch_response( - session, - f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', - 'text', - ) - return snippet_to_codeblock(file_contents, file_path, start_line, end_line) - - -def snippet_to_codeblock(file_contents: str, file_path: str, start_line: str, end_line: str) -> str: - """ - Given the entire file contents and target lines, creates a code block. - - First, we split the file contents into a list of lines and then keep and join only the required - ones together. - - We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent - markdown injection. - - Finally, we surround the code with ``` characters. - """ - # Parse start_line and end_line into integers - if end_line is None: - start_line = end_line = int(start_line) - else: - start_line = int(start_line) - end_line = int(end_line) - - split_file_contents = file_contents.splitlines() - - # Make sure that the specified lines are in range - if start_line > end_line: - start_line, end_line = end_line, start_line - if start_line > len(split_file_contents) or end_line < 1: - return '' - start_line = max(1, start_line) - end_line = min(len(split_file_contents), end_line) - - # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection - required = '\n'.join(split_file_contents[start_line - 1:end_line]) - required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') - - # Extracts the code language and checks whether it's a "valid" language - language = file_path.split('/')[-1].split('.')[-1] - trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') - is_valid_language = trimmed_language.isalnum() - if not is_valid_language: - language = '' - - if len(required) != 0: - return f'```{language}\n{required}```\n' - return '' - - GITHUB_RE = re.compile( r'https://github\.com/(?P<repo>.+?)/blob/(?P<path>.+/.+)' r'#L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b' @@ -183,6 +38,160 @@ class CodeSnippets(Cog): Matches each message against a regex and prints the contents of all matched snippets. """ + async def _fetch_response(self, url: str, response_format: str, **kwargs) -> str: + """Makes http requests using aiohttp.""" + async with self.bot.http_session.get(url, **kwargs) as response: + if response_format == 'text': + return await response.text() + elif response_format == 'json': + return await response.json() + + def _find_ref(self, path: str, refs: tuple) -> tuple: + """Loops through all branches and tags to find the required ref.""" + # Base case: there is no slash in the branch name + ref = path.split('/')[0] + file_path = '/'.join(path.split('/')[1:]) + # In case there are slashes in the branch name, we loop through all branches and tags + for possible_ref in refs: + if path.startswith(possible_ref['name'] + '/'): + ref = possible_ref['name'] + file_path = path[len(ref) + 1:] + break + return (ref, file_path) + + async def _fetch_github_snippet( + self, + repo: str, + path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitHub repo.""" + headers = {'Accept': 'application/vnd.github.v3.raw'} + + # Search the GitHub API for the specified branch + branches = await self._fetch_response(f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers) + tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers) + refs = branches + tags + ref, file_path = self._find_ref(path, refs) + + file_contents = await self._fetch_response( + f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', + 'text', + headers=headers, + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + async def _fetch_github_gist_snippet( + self, + gist_id: str, + revision: str, + file_path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitHub gist.""" + headers = {'Accept': 'application/vnd.github.v3.raw'} + + gist_json = await self._fetch_response( + f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', + 'json', + headers=headers, + ) + + # Check each file in the gist for the specified file + for gist_file in gist_json['files']: + if file_path == gist_file.lower().replace('.', '-'): + file_contents = await self._fetch_response( + gist_json['files'][gist_file]['raw_url'], + 'text', + ) + return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line) + return '' + + async def _fetch_gitlab_snippet( + self, + repo: str, + path: str, + start_line: str, + end_line: str + ) -> str: + """Fetches a snippet from a GitLab repo.""" + enc_repo = quote_plus(repo) + + # Searches the GitLab API for the specified branch + branches = await self._fetch_response(f'https://api.github.com/repos/{repo}/branches', 'json') + tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json') + refs = branches + tags + ref, file_path = self._find_ref(path, refs) + enc_ref = quote_plus(ref) + enc_file_path = quote_plus(file_path) + + file_contents = await self._fetch_response( + f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', + 'text', + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + async def _fetch_bitbucket_snippet( + self, + repo: str, + ref: str, + file_path: str, + start_line: int, + end_line: int + ) -> str: + """Fetches a snippet from a BitBucket repo.""" + file_contents = await self._fetch_response( + f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', + 'text', + ) + return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + + def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str: + """ + Given the entire file contents and target lines, creates a code block. + + First, we split the file contents into a list of lines and then keep and join only the required + ones together. + + We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent + markdown injection. + + Finally, we surround the code with ``` characters. + """ + # Parse start_line and end_line into integers + if end_line is None: + start_line = end_line = int(start_line) + else: + start_line = int(start_line) + end_line = int(end_line) + + split_file_contents = file_contents.splitlines() + + # Make sure that the specified lines are in range + if start_line > end_line: + start_line, end_line = end_line, start_line + if start_line > len(split_file_contents) or end_line < 1: + return '' + start_line = max(1, start_line) + end_line = min(len(split_file_contents), end_line) + + # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection + required = '\n'.join(split_file_contents[start_line - 1:end_line]) + required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') + + # Extracts the code language and checks whether it's a "valid" language + language = file_path.split('/')[-1].split('.')[-1] + trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') + is_valid_language = trimmed_language.isalnum() + if not is_valid_language: + language = '' + + if len(required) != 0: + return f'```{language}\n{required}```\n' + return '' + def __init__(self, bot: Bot): """Initializes the cog's bot.""" self.bot = bot @@ -199,16 +208,16 @@ class CodeSnippets(Cog): message_to_send = '' for gh in GITHUB_RE.finditer(message.content): - message_to_send += await fetch_github_snippet(self.bot.http_session, **gh.groupdict()) + message_to_send += await self._fetch_github_snippet(**gh.groupdict()) for gh_gist in GITHUB_GIST_RE.finditer(message.content): - message_to_send += await fetch_github_gist_snippet(self.bot.http_session, **gh_gist.groupdict()) + message_to_send += await self._fetch_github_gist_snippet(**gh_gist.groupdict()) for gl in GITLAB_RE.finditer(message.content): - message_to_send += await fetch_gitlab_snippet(self.bot.http_session, **gl.groupdict()) + message_to_send += await self._fetch_gitlab_snippet(**gl.groupdict()) for bb in BITBUCKET_RE.finditer(message.content): - message_to_send += await fetch_bitbucket_snippet(self.bot.http_session, **bb.groupdict()) + message_to_send += await self._fetch_bitbucket_snippet(**bb.groupdict()) if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: await message.edit(suppress=True) |