aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Joe Banks <[email protected]>2021-04-27 11:08:37 +0100
committerGravatar GitHub <[email protected]>2021-04-27 11:08:37 +0100
commit3c0da0b5718780ab1afd611d53ff520c0b727709 (patch)
treeb50da96c6c01f9d21fceac7e5f67455ea209a7b6
parentMerge pull request #1426 from python-discord/feat/1365/add-bot-badges-to-user (diff)
parentMerge branch 'main' into master (diff)
Merge pull request #1028 from dolphingarlic/master
-rw-r--r--bot/exts/info/code_snippets.py247
1 files changed, 247 insertions, 0 deletions
diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py
new file mode 100644
index 000000000..c20115830
--- /dev/null
+++ b/bot/exts/info/code_snippets.py
@@ -0,0 +1,247 @@
+import logging
+import re
+import textwrap
+from urllib.parse import quote_plus
+
+from aiohttp import ClientResponseError
+from discord import Message
+from discord.ext.commands import Cog
+
+from bot.bot import Bot
+from bot.utils.messages import wait_for_deletion
+
+log = logging.getLogger(__name__)
+
+GITHUB_RE = re.compile(
+ r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/'
+ r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)'
+)
+
+GITHUB_GIST_RE = re.compile(
+ r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*'
+ r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?'
+ r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)'
+)
+
+GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'}
+
+GITLAB_RE = re.compile(
+ r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)'
+ r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)'
+)
+
+BITBUCKET_RE = re.compile(
+ r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)'
+ r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)'
+)
+
+
+class CodeSnippets(Cog):
+ """
+ Cog that parses and sends code snippets to Discord.
+
+ Matches each message against a regex and prints the contents of all matched snippets.
+ """
+
+ async def _fetch_response(self, url: str, response_format: str, **kwargs) -> str:
+ """Makes http requests using aiohttp."""
+ try:
+ async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response:
+ if response_format == 'text':
+ return await response.text()
+ elif response_format == 'json':
+ return await response.json()
+ except ClientResponseError as error:
+ log.error(f'Failed to fetch code snippet from {url}. HTTP Status: {error.status}. Message: {str(error)}.')
+
+ def _find_ref(self, path: str, refs: tuple) -> tuple:
+ """Loops through all branches and tags to find the required ref."""
+ # Base case: there is no slash in the branch name
+ ref, file_path = path.split('/', 1)
+ # In case there are slashes in the branch name, we loop through all branches and tags
+ for possible_ref in refs:
+ if path.startswith(possible_ref['name'] + '/'):
+ ref = possible_ref['name']
+ file_path = path[len(ref) + 1:]
+ break
+ return (ref, file_path)
+
+ async def _fetch_github_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub repo."""
+ # Search the GitHub API for the specified branch
+ branches = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/branches',
+ 'json',
+ headers=GITHUB_HEADERS
+ )
+ tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS)
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+
+ file_contents = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}',
+ 'text',
+ headers=GITHUB_HEADERS,
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_github_gist_snippet(
+ self,
+ gist_id: str,
+ revision: str,
+ file_path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub gist."""
+ gist_json = await self._fetch_response(
+ f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}',
+ 'json',
+ headers=GITHUB_HEADERS,
+ )
+
+ # Check each file in the gist for the specified file
+ for gist_file in gist_json['files']:
+ if file_path == gist_file.lower().replace('.', '-'):
+ file_contents = await self._fetch_response(
+ gist_json['files'][gist_file]['raw_url'],
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line)
+ return ''
+
+ async def _fetch_gitlab_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitLab repo."""
+ enc_repo = quote_plus(repo)
+
+ # Searches the GitLab API for the specified branch
+ branches = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches',
+ 'json'
+ )
+ tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json')
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+ enc_ref = quote_plus(ref)
+ enc_file_path = quote_plus(file_path)
+
+ file_contents = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_bitbucket_snippet(
+ self,
+ repo: str,
+ ref: str,
+ file_path: str,
+ start_line: int,
+ end_line: int
+ ) -> str:
+ """Fetches a snippet from a BitBucket repo."""
+ file_contents = await self._fetch_response(
+ f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str:
+ """
+ Given the entire file contents and target lines, creates a code block.
+
+ First, we split the file contents into a list of lines and then keep and join only the required
+ ones together.
+
+ We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent
+ markdown injection.
+
+ Finally, we surround the code with ``` characters.
+ """
+ # Parse start_line and end_line into integers
+ if end_line is None:
+ start_line = end_line = int(start_line)
+ else:
+ start_line = int(start_line)
+ end_line = int(end_line)
+
+ split_file_contents = file_contents.splitlines()
+
+ # Make sure that the specified lines are in range
+ if start_line > end_line:
+ start_line, end_line = end_line, start_line
+ if start_line > len(split_file_contents) or end_line < 1:
+ return ''
+ start_line = max(1, start_line)
+ end_line = min(len(split_file_contents), end_line)
+
+ # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection
+ required = '\n'.join(split_file_contents[start_line - 1:end_line])
+ required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
+
+ # Extracts the code language and checks whether it's a "valid" language
+ language = file_path.split('/')[-1].split('.')[-1]
+ trimmed_language = language.replace('-', '').replace('+', '').replace('_', '')
+ is_valid_language = trimmed_language.isalnum()
+ if not is_valid_language:
+ language = ''
+
+ # Adds a label showing the file path to the snippet
+ if start_line == end_line:
+ ret = f'`{file_path}` line {start_line}\n'
+ else:
+ ret = f'`{file_path}` lines {start_line} to {end_line}\n'
+
+ if len(required) != 0:
+ return f'{ret}```{language}\n{required}```'
+ # Returns an empty codeblock if the snippet is empty
+ return f'{ret}``` ```'
+
+ def __init__(self, bot: Bot):
+ """Initializes the cog's bot."""
+ self.bot = bot
+
+ self.pattern_handlers = [
+ (GITHUB_RE, self._fetch_github_snippet),
+ (GITHUB_GIST_RE, self._fetch_github_gist_snippet),
+ (GITLAB_RE, self._fetch_gitlab_snippet),
+ (BITBUCKET_RE, self._fetch_bitbucket_snippet)
+ ]
+
+ @Cog.listener()
+ async def on_message(self, message: Message) -> None:
+ """Checks if the message has a snippet link, removes the embed, then sends the snippet contents."""
+ if not message.author.bot:
+ all_snippets = []
+
+ for pattern, handler in self.pattern_handlers:
+ for match in pattern.finditer(message.content):
+ snippet = await handler(**match.groupdict())
+ all_snippets.append((match.start(), snippet))
+
+ # Sorts the list of snippets by their match index and joins them into a single message
+ message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets)))
+
+ if 0 < len(message_to_send) <= 2000 and len(all_snippets) <= 15:
+ await message.edit(suppress=True)
+ await wait_for_deletion(
+ await message.channel.send(message_to_send),
+ (message.author.id,)
+ )
+
+
+def setup(bot: Bot) -> None:
+ """Load the CodeSnippets cog."""
+ bot.add_cog(CodeSnippets(bot))