diff options
| -rw-r--r-- | bot/exts/info/code_snippets.py | 307 | 
1 files changed, 158 insertions, 149 deletions
| diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py index c53c28e8b..12eb692d4 100644 --- a/bot/exts/info/code_snippets.py +++ b/bot/exts/info/code_snippets.py @@ -2,7 +2,6 @@ import re  import textwrap  from urllib.parse import quote_plus -from aiohttp import ClientSession  from discord import Message  from discord.ext.commands import Cog @@ -10,150 +9,6 @@ from bot.bot import Bot  from bot.utils.messages import wait_for_deletion -async def fetch_response(session: ClientSession, url: str, response_format: str, **kwargs) -> str: -    """Makes http requests using aiohttp.""" -    async with session.get(url, **kwargs) as response: -        if response_format == 'text': -            return await response.text() -        elif response_format == 'json': -            return await response.json() - - -def find_ref(path: str, refs: tuple) -> tuple: -    """Loops through all branches and tags to find the required ref.""" -    # Base case: there is no slash in the branch name -    ref = path.split('/')[0] -    file_path = '/'.join(path.split('/')[1:]) -    # In case there are slashes in the branch name, we loop through all branches and tags -    for possible_ref in refs: -        if path.startswith(possible_ref['name'] + '/'): -            ref = possible_ref['name'] -            file_path = path[len(ref) + 1:] -            break -    return (ref, file_path) - - -async def fetch_github_snippet(session: ClientSession, repo: str, -                               path: str, start_line: str, end_line: str) -> str: -    """Fetches a snippet from a GitHub repo.""" -    headers = {'Accept': 'application/vnd.github.v3.raw'} - -    # Search the GitHub API for the specified branch -    branches = await fetch_response(session, f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers) -    tags = await fetch_response(session, f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers) -    refs = branches + tags -    ref, file_path = find_ref(path, refs) - -    file_contents = await fetch_response( -        session, -        f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', -        'text', -        headers=headers, -    ) -    return snippet_to_codeblock(file_contents, file_path, start_line, end_line) - - -async def fetch_github_gist_snippet(session: ClientSession, gist_id: str, revision: str, -                                    file_path: str, start_line: str, end_line: str) -> str: -    """Fetches a snippet from a GitHub gist.""" -    headers = {'Accept': 'application/vnd.github.v3.raw'} - -    gist_json = await fetch_response( -        session, -        f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', -        'json', -        headers=headers, -    ) - -    # Check each file in the gist for the specified file -    for gist_file in gist_json['files']: -        if file_path == gist_file.lower().replace('.', '-'): -            file_contents = await fetch_response( -                session, -                gist_json['files'][gist_file]['raw_url'], -                'text', -            ) -            return snippet_to_codeblock(file_contents, gist_file, start_line, end_line) -    return '' - - -async def fetch_gitlab_snippet(session: ClientSession, repo: str, -                               path: str, start_line: str, end_line: str) -> str: -    """Fetches a snippet from a GitLab repo.""" -    enc_repo = quote_plus(repo) - -    # Searches the GitLab API for the specified branch -    branches = await fetch_response(session, f'https://api.github.com/repos/{repo}/branches', 'json') -    tags = await fetch_response(session, f'https://api.github.com/repos/{repo}/tags', 'json') -    refs = branches + tags -    ref, file_path = find_ref(path, refs) -    enc_ref = quote_plus(ref) -    enc_file_path = quote_plus(file_path) - -    file_contents = await fetch_response( -        session, -        f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', -        'text', -    ) -    return snippet_to_codeblock(file_contents, file_path, start_line, end_line) - - -async def fetch_bitbucket_snippet(session: ClientSession, repo: str, ref: str, -                                  file_path: str, start_line: int, end_line: int) -> str: -    """Fetches a snippet from a BitBucket repo.""" -    file_contents = await fetch_response( -        session, -        f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', -        'text', -    ) -    return snippet_to_codeblock(file_contents, file_path, start_line, end_line) - - -def snippet_to_codeblock(file_contents: str, file_path: str, start_line: str, end_line: str) -> str: -    """ -    Given the entire file contents and target lines, creates a code block. - -    First, we split the file contents into a list of lines and then keep and join only the required -    ones together. - -    We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent -    markdown injection. - -    Finally, we surround the code with ``` characters. -    """ -    # Parse start_line and end_line into integers -    if end_line is None: -        start_line = end_line = int(start_line) -    else: -        start_line = int(start_line) -        end_line = int(end_line) - -    split_file_contents = file_contents.splitlines() - -    # Make sure that the specified lines are in range -    if start_line > end_line: -        start_line, end_line = end_line, start_line -    if start_line > len(split_file_contents) or end_line < 1: -        return '' -    start_line = max(1, start_line) -    end_line = min(len(split_file_contents), end_line) - -    # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection -    required = '\n'.join(split_file_contents[start_line - 1:end_line]) -    required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') - -    # Extracts the code language and checks whether it's a "valid" language -    language = file_path.split('/')[-1].split('.')[-1] -    trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') -    is_valid_language = trimmed_language.isalnum() -    if not is_valid_language: -        language = '' - -    if len(required) != 0: -        return f'```{language}\n{required}```\n' -    return '' - -  GITHUB_RE = re.compile(      r'https://github\.com/(?P<repo>.+?)/blob/(?P<path>.+/.+)'      r'#L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b' @@ -183,6 +38,160 @@ class CodeSnippets(Cog):      Matches each message against a regex and prints the contents of all matched snippets.      """ +    async def _fetch_response(self, url: str, response_format: str, **kwargs) -> str: +        """Makes http requests using aiohttp.""" +        async with self.bot.http_session.get(url, **kwargs) as response: +            if response_format == 'text': +                return await response.text() +            elif response_format == 'json': +                return await response.json() + +    def _find_ref(self, path: str, refs: tuple) -> tuple: +        """Loops through all branches and tags to find the required ref.""" +        # Base case: there is no slash in the branch name +        ref = path.split('/')[0] +        file_path = '/'.join(path.split('/')[1:]) +        # In case there are slashes in the branch name, we loop through all branches and tags +        for possible_ref in refs: +            if path.startswith(possible_ref['name'] + '/'): +                ref = possible_ref['name'] +                file_path = path[len(ref) + 1:] +                break +        return (ref, file_path) + +    async def _fetch_github_snippet( +        self, +        repo: str, +        path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitHub repo.""" +        headers = {'Accept': 'application/vnd.github.v3.raw'} + +        # Search the GitHub API for the specified branch +        branches = await self._fetch_response(f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers) +        tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers) +        refs = branches + tags +        ref, file_path = self._find_ref(path, refs) + +        file_contents = await self._fetch_response( +            f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', +            'text', +            headers=headers, +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    async def _fetch_github_gist_snippet( +        self, +        gist_id: str, +        revision: str, +        file_path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitHub gist.""" +        headers = {'Accept': 'application/vnd.github.v3.raw'} + +        gist_json = await self._fetch_response( +            f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', +            'json', +            headers=headers, +        ) + +        # Check each file in the gist for the specified file +        for gist_file in gist_json['files']: +            if file_path == gist_file.lower().replace('.', '-'): +                file_contents = await self._fetch_response( +                    gist_json['files'][gist_file]['raw_url'], +                    'text', +                ) +                return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line) +        return '' + +    async def _fetch_gitlab_snippet( +        self, +        repo: str, +        path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitLab repo.""" +        enc_repo = quote_plus(repo) + +        # Searches the GitLab API for the specified branch +        branches = await self._fetch_response(f'https://api.github.com/repos/{repo}/branches', 'json') +        tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json') +        refs = branches + tags +        ref, file_path = self._find_ref(path, refs) +        enc_ref = quote_plus(ref) +        enc_file_path = quote_plus(file_path) + +        file_contents = await self._fetch_response( +            f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', +            'text', +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    async def _fetch_bitbucket_snippet( +        self, +        repo: str, +        ref: str, +        file_path: str, +        start_line: int, +        end_line: int +    ) -> str: +        """Fetches a snippet from a BitBucket repo.""" +        file_contents = await self._fetch_response( +            f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', +            'text', +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str: +        """ +        Given the entire file contents and target lines, creates a code block. + +        First, we split the file contents into a list of lines and then keep and join only the required +        ones together. + +        We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent +        markdown injection. + +        Finally, we surround the code with ``` characters. +        """ +        # Parse start_line and end_line into integers +        if end_line is None: +            start_line = end_line = int(start_line) +        else: +            start_line = int(start_line) +            end_line = int(end_line) + +        split_file_contents = file_contents.splitlines() + +        # Make sure that the specified lines are in range +        if start_line > end_line: +            start_line, end_line = end_line, start_line +        if start_line > len(split_file_contents) or end_line < 1: +            return '' +        start_line = max(1, start_line) +        end_line = min(len(split_file_contents), end_line) + +        # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection +        required = '\n'.join(split_file_contents[start_line - 1:end_line]) +        required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') + +        # Extracts the code language and checks whether it's a "valid" language +        language = file_path.split('/')[-1].split('.')[-1] +        trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') +        is_valid_language = trimmed_language.isalnum() +        if not is_valid_language: +            language = '' + +        if len(required) != 0: +            return f'```{language}\n{required}```\n' +        return '' +      def __init__(self, bot: Bot):          """Initializes the cog's bot."""          self.bot = bot @@ -199,16 +208,16 @@ class CodeSnippets(Cog):              message_to_send = ''              for gh in GITHUB_RE.finditer(message.content): -                message_to_send += await fetch_github_snippet(self.bot.http_session, **gh.groupdict()) +                message_to_send += await self._fetch_github_snippet(**gh.groupdict())              for gh_gist in GITHUB_GIST_RE.finditer(message.content): -                message_to_send += await fetch_github_gist_snippet(self.bot.http_session, **gh_gist.groupdict()) +                message_to_send += await self._fetch_github_gist_snippet(**gh_gist.groupdict())              for gl in GITLAB_RE.finditer(message.content): -                message_to_send += await fetch_gitlab_snippet(self.bot.http_session, **gl.groupdict()) +                message_to_send += await self._fetch_gitlab_snippet(**gl.groupdict())              for bb in BITBUCKET_RE.finditer(message.content): -                message_to_send += await fetch_bitbucket_snippet(self.bot.http_session, **bb.groupdict()) +                message_to_send += await self._fetch_bitbucket_snippet(**bb.groupdict())              if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15:                  await message.edit(suppress=True) | 
