aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Andi Qu <[email protected]>2020-10-27 15:45:09 +0200
committerGravatar Andi Qu <[email protected]>2020-10-27 15:45:09 +0200
commit76afc563ac73f6b8d40194c15e28f42a9fe6be0f (patch)
tree2923d19e413d64d5eb63bf52c6067947ac140f3e
parentMade check for valid language easier to read (diff)
Moved global functions into the cog and got rid of unnecessary aiohttp sessions
-rw-r--r--bot/exts/info/code_snippets.py307
1 files changed, 158 insertions, 149 deletions
diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py
index c53c28e8b..12eb692d4 100644
--- a/bot/exts/info/code_snippets.py
+++ b/bot/exts/info/code_snippets.py
@@ -2,7 +2,6 @@ import re
import textwrap
from urllib.parse import quote_plus
-from aiohttp import ClientSession
from discord import Message
from discord.ext.commands import Cog
@@ -10,150 +9,6 @@ from bot.bot import Bot
from bot.utils.messages import wait_for_deletion
-async def fetch_response(session: ClientSession, url: str, response_format: str, **kwargs) -> str:
- """Makes http requests using aiohttp."""
- async with session.get(url, **kwargs) as response:
- if response_format == 'text':
- return await response.text()
- elif response_format == 'json':
- return await response.json()
-
-
-def find_ref(path: str, refs: tuple) -> tuple:
- """Loops through all branches and tags to find the required ref."""
- # Base case: there is no slash in the branch name
- ref = path.split('/')[0]
- file_path = '/'.join(path.split('/')[1:])
- # In case there are slashes in the branch name, we loop through all branches and tags
- for possible_ref in refs:
- if path.startswith(possible_ref['name'] + '/'):
- ref = possible_ref['name']
- file_path = path[len(ref) + 1:]
- break
- return (ref, file_path)
-
-
-async def fetch_github_snippet(session: ClientSession, repo: str,
- path: str, start_line: str, end_line: str) -> str:
- """Fetches a snippet from a GitHub repo."""
- headers = {'Accept': 'application/vnd.github.v3.raw'}
-
- # Search the GitHub API for the specified branch
- branches = await fetch_response(session, f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers)
- tags = await fetch_response(session, f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers)
- refs = branches + tags
- ref, file_path = find_ref(path, refs)
-
- file_contents = await fetch_response(
- session,
- f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}',
- 'text',
- headers=headers,
- )
- return snippet_to_codeblock(file_contents, file_path, start_line, end_line)
-
-
-async def fetch_github_gist_snippet(session: ClientSession, gist_id: str, revision: str,
- file_path: str, start_line: str, end_line: str) -> str:
- """Fetches a snippet from a GitHub gist."""
- headers = {'Accept': 'application/vnd.github.v3.raw'}
-
- gist_json = await fetch_response(
- session,
- f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}',
- 'json',
- headers=headers,
- )
-
- # Check each file in the gist for the specified file
- for gist_file in gist_json['files']:
- if file_path == gist_file.lower().replace('.', '-'):
- file_contents = await fetch_response(
- session,
- gist_json['files'][gist_file]['raw_url'],
- 'text',
- )
- return snippet_to_codeblock(file_contents, gist_file, start_line, end_line)
- return ''
-
-
-async def fetch_gitlab_snippet(session: ClientSession, repo: str,
- path: str, start_line: str, end_line: str) -> str:
- """Fetches a snippet from a GitLab repo."""
- enc_repo = quote_plus(repo)
-
- # Searches the GitLab API for the specified branch
- branches = await fetch_response(session, f'https://api.github.com/repos/{repo}/branches', 'json')
- tags = await fetch_response(session, f'https://api.github.com/repos/{repo}/tags', 'json')
- refs = branches + tags
- ref, file_path = find_ref(path, refs)
- enc_ref = quote_plus(ref)
- enc_file_path = quote_plus(file_path)
-
- file_contents = await fetch_response(
- session,
- f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}',
- 'text',
- )
- return snippet_to_codeblock(file_contents, file_path, start_line, end_line)
-
-
-async def fetch_bitbucket_snippet(session: ClientSession, repo: str, ref: str,
- file_path: str, start_line: int, end_line: int) -> str:
- """Fetches a snippet from a BitBucket repo."""
- file_contents = await fetch_response(
- session,
- f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}',
- 'text',
- )
- return snippet_to_codeblock(file_contents, file_path, start_line, end_line)
-
-
-def snippet_to_codeblock(file_contents: str, file_path: str, start_line: str, end_line: str) -> str:
- """
- Given the entire file contents and target lines, creates a code block.
-
- First, we split the file contents into a list of lines and then keep and join only the required
- ones together.
-
- We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent
- markdown injection.
-
- Finally, we surround the code with ``` characters.
- """
- # Parse start_line and end_line into integers
- if end_line is None:
- start_line = end_line = int(start_line)
- else:
- start_line = int(start_line)
- end_line = int(end_line)
-
- split_file_contents = file_contents.splitlines()
-
- # Make sure that the specified lines are in range
- if start_line > end_line:
- start_line, end_line = end_line, start_line
- if start_line > len(split_file_contents) or end_line < 1:
- return ''
- start_line = max(1, start_line)
- end_line = min(len(split_file_contents), end_line)
-
- # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection
- required = '\n'.join(split_file_contents[start_line - 1:end_line])
- required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
-
- # Extracts the code language and checks whether it's a "valid" language
- language = file_path.split('/')[-1].split('.')[-1]
- trimmed_language = language.replace('-', '').replace('+', '').replace('_', '')
- is_valid_language = trimmed_language.isalnum()
- if not is_valid_language:
- language = ''
-
- if len(required) != 0:
- return f'```{language}\n{required}```\n'
- return ''
-
-
GITHUB_RE = re.compile(
r'https://github\.com/(?P<repo>.+?)/blob/(?P<path>.+/.+)'
r'#L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b'
@@ -183,6 +38,160 @@ class CodeSnippets(Cog):
Matches each message against a regex and prints the contents of all matched snippets.
"""
+ async def _fetch_response(self, url: str, response_format: str, **kwargs) -> str:
+ """Makes http requests using aiohttp."""
+ async with self.bot.http_session.get(url, **kwargs) as response:
+ if response_format == 'text':
+ return await response.text()
+ elif response_format == 'json':
+ return await response.json()
+
+ def _find_ref(self, path: str, refs: tuple) -> tuple:
+ """Loops through all branches and tags to find the required ref."""
+ # Base case: there is no slash in the branch name
+ ref = path.split('/')[0]
+ file_path = '/'.join(path.split('/')[1:])
+ # In case there are slashes in the branch name, we loop through all branches and tags
+ for possible_ref in refs:
+ if path.startswith(possible_ref['name'] + '/'):
+ ref = possible_ref['name']
+ file_path = path[len(ref) + 1:]
+ break
+ return (ref, file_path)
+
+ async def _fetch_github_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub repo."""
+ headers = {'Accept': 'application/vnd.github.v3.raw'}
+
+ # Search the GitHub API for the specified branch
+ branches = await self._fetch_response(f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers)
+ tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers)
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+
+ file_contents = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}',
+ 'text',
+ headers=headers,
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_github_gist_snippet(
+ self,
+ gist_id: str,
+ revision: str,
+ file_path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub gist."""
+ headers = {'Accept': 'application/vnd.github.v3.raw'}
+
+ gist_json = await self._fetch_response(
+ f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}',
+ 'json',
+ headers=headers,
+ )
+
+ # Check each file in the gist for the specified file
+ for gist_file in gist_json['files']:
+ if file_path == gist_file.lower().replace('.', '-'):
+ file_contents = await self._fetch_response(
+ gist_json['files'][gist_file]['raw_url'],
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line)
+ return ''
+
+ async def _fetch_gitlab_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitLab repo."""
+ enc_repo = quote_plus(repo)
+
+ # Searches the GitLab API for the specified branch
+ branches = await self._fetch_response(f'https://api.github.com/repos/{repo}/branches', 'json')
+ tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json')
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+ enc_ref = quote_plus(ref)
+ enc_file_path = quote_plus(file_path)
+
+ file_contents = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_bitbucket_snippet(
+ self,
+ repo: str,
+ ref: str,
+ file_path: str,
+ start_line: int,
+ end_line: int
+ ) -> str:
+ """Fetches a snippet from a BitBucket repo."""
+ file_contents = await self._fetch_response(
+ f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str:
+ """
+ Given the entire file contents and target lines, creates a code block.
+
+ First, we split the file contents into a list of lines and then keep and join only the required
+ ones together.
+
+ We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent
+ markdown injection.
+
+ Finally, we surround the code with ``` characters.
+ """
+ # Parse start_line and end_line into integers
+ if end_line is None:
+ start_line = end_line = int(start_line)
+ else:
+ start_line = int(start_line)
+ end_line = int(end_line)
+
+ split_file_contents = file_contents.splitlines()
+
+ # Make sure that the specified lines are in range
+ if start_line > end_line:
+ start_line, end_line = end_line, start_line
+ if start_line > len(split_file_contents) or end_line < 1:
+ return ''
+ start_line = max(1, start_line)
+ end_line = min(len(split_file_contents), end_line)
+
+ # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection
+ required = '\n'.join(split_file_contents[start_line - 1:end_line])
+ required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
+
+ # Extracts the code language and checks whether it's a "valid" language
+ language = file_path.split('/')[-1].split('.')[-1]
+ trimmed_language = language.replace('-', '').replace('+', '').replace('_', '')
+ is_valid_language = trimmed_language.isalnum()
+ if not is_valid_language:
+ language = ''
+
+ if len(required) != 0:
+ return f'```{language}\n{required}```\n'
+ return ''
+
def __init__(self, bot: Bot):
"""Initializes the cog's bot."""
self.bot = bot
@@ -199,16 +208,16 @@ class CodeSnippets(Cog):
message_to_send = ''
for gh in GITHUB_RE.finditer(message.content):
- message_to_send += await fetch_github_snippet(self.bot.http_session, **gh.groupdict())
+ message_to_send += await self._fetch_github_snippet(**gh.groupdict())
for gh_gist in GITHUB_GIST_RE.finditer(message.content):
- message_to_send += await fetch_github_gist_snippet(self.bot.http_session, **gh_gist.groupdict())
+ message_to_send += await self._fetch_github_gist_snippet(**gh_gist.groupdict())
for gl in GITLAB_RE.finditer(message.content):
- message_to_send += await fetch_gitlab_snippet(self.bot.http_session, **gl.groupdict())
+ message_to_send += await self._fetch_gitlab_snippet(**gl.groupdict())
for bb in BITBUCKET_RE.finditer(message.content):
- message_to_send += await fetch_bitbucket_snippet(self.bot.http_session, **bb.groupdict())
+ message_to_send += await self._fetch_bitbucket_snippet(**bb.groupdict())
if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15:
await message.edit(suppress=True)