aboutsummaryrefslogtreecommitdiffstats
path: root/bot/cogs/print_snippets.py
blob: 5c83cd62b6891570c7d210f0b61ffc4ad0fc2327 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import os
import re
import textwrap

import aiohttp
from discord import Message
from discord.ext.commands import Cog

from bot.bot import Bot


async def fetch_http(session: aiohttp.ClientSession, url: str, response_format: str, **kwargs) -> str:
    """
    Uses aiohttp to make http GET requests.

    ``response_format`` selects how the response body is decoded: ``'text'`` returns it as a
    string, ``'json'`` returns the decoded JSON value. Any extra keyword arguments (for example
    ``headers``) are passed straight through to ``session.get``. The HTTP status code is not
    checked, so the body of an error response is returned like any other body.

    Raises:
        ValueError: if ``response_format`` is neither ``'text'`` nor ``'json'``.
    """
    # Validate before touching the network so a bad format fails loudly instead of
    # performing a request and then returning None.
    if response_format not in ('text', 'json'):
        raise ValueError(f"response_format must be 'text' or 'json', not {response_format!r}")
    async with session.get(url, **kwargs) as response:
        if response_format == 'text':
            return await response.text()
        return await response.json()


async def revert_to_orig(d: dict) -> dict:
    """
    Replace URL Encoded values back to their original.

    Undoes ``orig_to_encode``: every non-None value in ``d`` has ``%2F`` turned back into ``/``
    and ``%2E`` back into ``.``. ``d`` is modified in place and is also returned, matching the
    declared return type.
    """
    for key, value in d.items():
        # Re-assigning existing keys does not resize the dict, so this is safe while iterating.
        if value is not None:
            d[key] = value.replace('%2F', '/').replace('%2E', '.')
    return d


async def orig_to_encode(d: dict) -> dict:
    """
    Encode URL Parameters.

    Every non-None value in ``d`` has ``/`` replaced by ``%2F`` and ``.`` by ``%2E`` so it can be
    embedded as a single path segment. ``d`` is modified in place and is also returned, matching
    the declared return type. ``revert_to_orig`` undoes this.
    """
    for key, value in d.items():
        # Re-assigning existing keys does not resize the dict, so this is safe while iterating.
        if value is not None:
            d[key] = value.replace('/', '%2F').replace('.', '%2E')
    return d


async def snippet_to_embed(d: dict, file_contents: str) -> str:
    """
    Given a regex groupdict and file contents, creates a code block.

    ``d`` must contain ``start_line``, ``end_line`` (falsy for a single-line link) and
    ``file_path``. Line numbers are 1-based and inclusive, may be given in either order, and are
    clamped to the file's bounds.

    Returns an empty string when the requested range lies entirely outside the file. Otherwise it
    returns a Markdown code block terminated by a newline, tagged with the file extension as the
    highlighting language.
    """
    start_line = int(d['start_line'])
    end_line = int(d['end_line']) if d['end_line'] else start_line

    # Normalise Windows line endings first. A leftover '\r' on each line would leak into the
    # message, and it makes blank lines look non-blank to textwrap.dedent, which defeats the dedent.
    split_file_contents = file_contents.replace('\r\n', '\n').split('\n')

    if start_line > end_line:
        start_line, end_line = end_line, start_line
    if start_line > len(split_file_contents) or end_line < 1:
        return ''
    start_line = max(1, start_line)
    end_line = min(len(split_file_contents), end_line)

    required = '\n'.join(split_file_contents[start_line - 1:end_line])
    # A zero-width space after every backtick stops the snippet from closing the code block early.
    required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')

    # Use the file extension as the highlighting language, but only if it is a plain token.
    language = d['file_path'].split('/')[-1].split('.')[-1]
    if not language.replace('-', '').replace('+', '').replace('_', '').isalnum():
        language = ''

    if required:
        return f'```{language}\n{required}```\n'
    return '``` ```\n'


# Each pattern is written as adjacent raw-string fragments, one per URL component; Python
# joins them at compile time into a single pattern.

# GitHub file link:
#   https://github.com/<owner>/<repo>/blob/<branch>/<path>#L<start>[(-|~)L<end>]
GITHUB_RE = re.compile(
    r'https://github\.com/(?P<repo>.+?)'
    r'/blob/(?P<branch>.+?)/'
    r'(?P<file_path>.+?)'
    r'#L(?P<start_line>\d+)'
    r'([-~]L(?P<end_line>\d+))?\b'
)

# GitHub gist link (optional revision):
#   https://gist.github.com/<user>/<gist_id>[/<revision>]#file-<name>-L<start>[(-|~)L<end>]
GITHUB_GIST_RE = re.compile(
    r'https://gist\.github\.com/([^/]*)'
    r'/(?P<gist_id>[0-9a-zA-Z]+)/*'
    r'(?P<revision>[0-9a-zA-Z]*)/*'
    r'#file-(?P<file_path>.+?)'
    r'-L(?P<start_line>\d+)'
    r'([-~]L(?P<end_line>\d+))?\b'
)

# GitLab file link (note the end line has no "L" prefix):
#   https://gitlab.com/<repo>/-/blob/<branch>/<path>#L<start>[(-|~)<end>]
GITLAB_RE = re.compile(
    r'https://gitlab\.com/(?P<repo>.+?)'
    r'/\-/blob/(?P<branch>.+?)/'
    r'(?P<file_path>.+?)'
    r'#L(?P<start_line>\d+)'
    r'([-~](?P<end_line>\d+))?\b'
)

# Bitbucket file link:
#   https://bitbucket.org/<repo>/src/<branch>/<path>#lines-<start>[:<end>]
BITBUCKET_RE = re.compile(
    r'https://bitbucket\.org/(?P<repo>.+?)'
    r'/src/(?P<branch>.+?)/'
    r'(?P<file_path>.+?)'
    r'#lines-(?P<start_line>\d+)'
    r'(:(?P<end_line>\d+))?\b'
)


class PrintSnippets(Cog):
    """
    Cog that prints out snippets to Discord.

    Matches each message against a regex and prints the contents of all matched snippets.
    """

    def __init__(self, bot: Bot):
        """Initializes the cog's bot and the HTTP session used to fetch snippet sources."""
        self.bot = bot
        self.session = aiohttp.ClientSession()

    def cog_unload(self) -> None:
        """Schedule closing of the HTTP session so its connections are not leaked on unload."""
        # NOTE(review): assumes Bot exposes discord.py's `loop`; confirm against bot.bot.Bot.
        self.bot.loop.create_task(self.session.close())

    @Cog.listener()
    async def on_message(self, message: Message) -> None:
        """Checks if the message has a snippet link, removes the embed, then sends the snippet contents."""
        if message.author.bot:
            return
        content = message.content
        if not any(
            pattern.search(content)
            for pattern in (GITHUB_RE, GITHUB_GIST_RE, GITLAB_RE, BITBUCKET_RE)
        ):
            return

        # Snippets are emitted host by host: GitHub, gists, GitLab, then Bitbucket.
        message_to_send = ''
        for collect in (
            self._github_snippets,
            self._gist_snippets,
            self._gitlab_snippets,
            self._bitbucket_snippets,
        ):
            message_to_send += await collect(content)

        # Each non-empty snippet ends with '\n'; drop the final one.
        message_to_send = message_to_send[:-1]

        if len(message_to_send) > 2000:
            await message.channel.send(
                'Sorry, Discord has a 2000 character limit. Please send a shorter '
                + 'snippet or split the big snippet up into several smaller ones :slight_smile:'
            )
        elif len(message_to_send) == 0:
            await message.channel.send(
                'Please send valid snippet links to prevent spam :slight_smile:'
            )
        elif message_to_send.count('\n') > 50:
            await message.channel.send(
                'Please limit the total number of lines to at most 50 to prevent spam :slight_smile:'
            )
        else:
            await message.channel.send(message_to_send)
        await message.edit(suppress=True)

    async def _github_snippets(self, content: str) -> str:
        """Return the concatenated code blocks for every GitHub file link in ``content``."""
        snippets = ''
        for match in GITHUB_RE.finditer(content):
            d = match.groupdict()
            headers = {'Accept': 'application/vnd.github.v3.raw'}
            if 'GITHUB_TOKEN' in os.environ:
                headers['Authorization'] = f'token {os.environ["GITHUB_TOKEN"]}'
            # Adjacent f-strings, not a backslash continuation inside the literal, which would
            # embed the next line's indentation into the URL.
            file_contents = await fetch_http(
                self.session,
                f'https://api.github.com/repos/{d["repo"]}'
                f'/contents/{d["file_path"]}?ref={d["branch"]}',
                'text',
                headers=headers,
            )
            snippets += await snippet_to_embed(d, file_contents)
        return snippets

    async def _gist_snippets(self, content: str) -> str:
        """Return the concatenated code blocks for every GitHub gist link in ``content``."""
        snippets = ''
        for match in GITHUB_GIST_RE.finditer(content):
            d = match.groupdict()
            revision = f'/{d["revision"]}' if d['revision'] else ''
            gist_json = await fetch_http(
                self.session,
                f'https://api.github.com/gists/{d["gist_id"]}{revision}',
                'json',
            )
            # Error payloads (e.g. an unknown gist) presumably carry no 'files' key;
            # skip them instead of raising KeyError out of the listener.
            for file_name, file_info in gist_json.get('files', {}).items():
                # The link anchor is the file name lower-cased with '.' replaced by '-'.
                if d['file_path'] == file_name.lower().replace('.', '-'):
                    d['file_path'] = file_name
                    file_contents = await fetch_http(
                        self.session,
                        file_info['raw_url'],
                        'text',
                    )
                    snippets += await snippet_to_embed(d, file_contents)
                    break
        return snippets

    async def _gitlab_snippets(self, content: str) -> str:
        """Return the concatenated code blocks for every GitLab file link in ``content``."""
        snippets = ''
        for match in GITLAB_RE.finditer(content):
            d = match.groupdict()
            # The API URL needs '/' and '.' escaped inside the project and file path segments.
            await orig_to_encode(d)
            headers = {}
            if 'GITLAB_TOKEN' in os.environ:
                headers['PRIVATE-TOKEN'] = os.environ['GITLAB_TOKEN']
            file_contents = await fetch_http(
                self.session,
                f'https://gitlab.com/api/v4/projects/{d["repo"]}'
                f'/repository/files/{d["file_path"]}/raw?ref={d["branch"]}',
                'text',
                headers=headers,
            )
            # Restore the plain path so the language is detected from the real extension.
            await revert_to_orig(d)
            snippets += await snippet_to_embed(d, file_contents)
        return snippets

    async def _bitbucket_snippets(self, content: str) -> str:
        """Return the concatenated code blocks for every Bitbucket file link in ``content``."""
        snippets = ''
        for match in BITBUCKET_RE.finditer(content):
            d = match.groupdict()
            await orig_to_encode(d)
            file_contents = await fetch_http(
                self.session,
                f'https://bitbucket.org/{d["repo"]}/raw/{d["branch"]}/{d["file_path"]}',
                'text',
            )
            await revert_to_orig(d)
            snippets += await snippet_to_embed(d, file_contents)
        return snippets


def setup(bot: Bot) -> None:
    """Load the PrintSnippets cog."""
    bot.add_cog(PrintSnippets(bot))