aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar dolphingarlic <[email protected]>2020-07-30 20:13:15 +0200
committerGravatar dolphingarlic <[email protected]>2020-07-30 20:13:15 +0200
commitb759a940a097effd16b761e0c62231ae0ca9562b (patch)
tree0a1c2f31f7cbad0992b4bf3899e6ff7814c3690d
parentRemoved repo widget prettification and added reaction to remove lines (diff)
Cleaned the code for CodeSnippets
-rw-r--r--bot/__main__.py2
-rw-r--r--bot/cogs/code_snippets.py216
-rw-r--r--bot/cogs/print_snippets.py190
3 files changed, 217 insertions, 191 deletions
diff --git a/bot/__main__.py b/bot/__main__.py
index 3191faf85..3d414c4b8 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -71,7 +71,7 @@ bot.load_extension("bot.cogs.utils")
bot.load_extension("bot.cogs.watchchannels")
bot.load_extension("bot.cogs.webhook_remover")
bot.load_extension("bot.cogs.wolfram")
-bot.load_extension("bot.cogs.print_snippets")
+bot.load_extension("bot.cogs.code_snippets")
if constants.HelpChannels.enable:
bot.load_extension("bot.cogs.help_channels")
diff --git a/bot/cogs/code_snippets.py b/bot/cogs/code_snippets.py
new file mode 100644
index 000000000..9bd06f6ff
--- /dev/null
+++ b/bot/cogs/code_snippets.py
@@ -0,0 +1,216 @@
+import re
+import textwrap
+from urllib.parse import quote_plus
+
+from aiohttp import ClientSession
+from discord import Message
+from discord.ext.commands import Cog
+
+from bot.bot import Bot
+from bot.utils.messages import wait_for_deletion
+
+
+async def fetch_http(session: ClientSession, url: str, response_format: str, **kwargs) -> str:
+ """Uses aiohttp to make http GET requests."""
+ async with session.get(url, **kwargs) as response:
+ if response_format == 'text':
+ return await response.text()
+ elif response_format == 'json':
+ return await response.json()
+
+
+async def fetch_github_snippet(session: ClientSession, repo: str,
+ path: str, start_line: str, end_line: str) -> str:
+ """Fetches a snippet from a GitHub repo."""
+ headers = {'Accept': 'application/vnd.github.v3.raw'}
+
+ # Search the GitHub API for the specified branch
+ refs = (await fetch_http(session, f'https://api.github.com/repos/{repo}/branches', 'json', headers=headers)
+ + await fetch_http(session, f'https://api.github.com/repos/{repo}/tags', 'json', headers=headers))
+
+ ref = path.split('/')[0]
+ file_path = '/'.join(path.split('/')[1:])
+ for possible_ref in refs:
+ if path.startswith(possible_ref['name'] + '/'):
+ ref = possible_ref['name']
+ file_path = path[len(ref) + 1:]
+ break
+
+ file_contents = await fetch_http(
+ session,
+ f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}',
+ 'text',
+ headers=headers,
+ )
+
+ return await snippet_to_md(file_contents, file_path, start_line, end_line)
+
+
+async def fetch_github_gist_snippet(session: ClientSession, gist_id: str, revision: str,
+ file_path: str, start_line: str, end_line: str) -> str:
+ """Fetches a snippet from a GitHub gist."""
+ headers = {'Accept': 'application/vnd.github.v3.raw'}
+
+ gist_json = await fetch_http(
+ session,
+ f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}',
+ 'json',
+ headers=headers,
+ )
+
+ # Check each file in the gist for the specified file
+ for gist_file in gist_json['files']:
+ if file_path == gist_file.lower().replace('.', '-'):
+ file_contents = await fetch_http(
+ session,
+ gist_json['files'][gist_file]['raw_url'],
+ 'text',
+ )
+
+ return await snippet_to_md(file_contents, gist_file, start_line, end_line)
+
+ return ''
+
+
+async def fetch_gitlab_snippet(session: ClientSession, repo: str,
+ path: str, start_line: str, end_line: str) -> str:
+ """Fetches a snippet from a GitLab repo."""
+ enc_repo = quote_plus(repo)
+
+ # Searches the GitLab API for the specified branch
+ refs = (await fetch_http(session, f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches', 'json')
+ + await fetch_http(session, f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json'))
+
+ ref = path.split('/')[0]
+ file_path = '/'.join(path.split('/')[1:])
+ for possible_ref in refs:
+ if path.startswith(possible_ref['name'] + '/'):
+ ref = possible_ref['name']
+ file_path = path[len(ref) + 1:]
+ break
+
+ enc_ref = quote_plus(ref)
+ enc_file_path = quote_plus(file_path)
+
+ file_contents = await fetch_http(
+ session,
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}',
+ 'text',
+ )
+
+ return await snippet_to_md(file_contents, file_path, start_line, end_line)
+
+
+async def fetch_bitbucket_snippet(session: ClientSession, repo: str, ref: str,
+ file_path: str, start_line: int, end_line: int) -> str:
+ """Fetches a snippet from a BitBucket repo."""
+ file_contents = await fetch_http(
+ session,
+ f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}',
+ 'text',
+ )
+
+ return await snippet_to_md(file_contents, file_path, start_line, end_line)
+
+
+async def snippet_to_md(file_contents: str, file_path: str, start_line: str, end_line: str) -> str:
+ """Given file contents, file path, start line and end line creates a code block."""
+ # Parse start_line and end_line into integers
+ if end_line is None:
+ start_line = end_line = int(start_line)
+ else:
+ start_line = int(start_line)
+ end_line = int(end_line)
+
+ split_file_contents = file_contents.splitlines()
+
+ # Make sure that the specified lines are in range
+ if start_line > end_line:
+ start_line, end_line = end_line, start_line
+ if start_line > len(split_file_contents) or end_line < 1:
+ return ''
+ start_line = max(1, start_line)
+ end_line = min(len(split_file_contents), end_line)
+
+ # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection
+ required = '\n'.join(split_file_contents[start_line - 1:end_line])
+ required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
+
+ # Extracts the code language and checks whether it's a "valid" language
+ language = file_path.split('/')[-1].split('.')[-1]
+ if not language.replace('-', '').replace('+', '').replace('_', '').isalnum():
+ language = ''
+
+ if len(required) != 0:
+ return f'```{language}\n{required}```\n'
+ return ''
+
+
+GITHUB_RE = re.compile(
+ r'https://github\.com/(?P<repo>.+?)/blob/(?P<path>.+/.+)'
+ r'#L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b'
+)
+
+GITHUB_GIST_RE = re.compile(
+ r'https://gist\.github\.com/([^/]+)/(?P<gist_id>[^\W_]+)/*'
+ r'(?P<revision>[^\W_]*)/*#file-(?P<file_path>.+?)'
+ r'-L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b'
+)
+
+GITLAB_RE = re.compile(
+ r'https://gitlab\.com/(?P<repo>.+?)/\-/blob/(?P<path>.+/.+)'
+ r'#L(?P<start_line>\d+)([-](?P<end_line>\d+))?\b'
+)
+
+BITBUCKET_RE = re.compile(
+ r'https://bitbucket\.org/(?P<repo>.+?)/src/(?P<ref>.+?)/'
+ r'(?P<file_path>.+?)#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?\b'
+)
+
+
+class CodeSnippets(Cog):
+ """
+ Cog that prints out snippets to Discord.
+
+ Matches each message against a regex and prints the contents of all matched snippets.
+ """
+
+ def __init__(self, bot: Bot):
+ """Initializes the cog's bot."""
+ self.bot = bot
+
+ @Cog.listener()
+ async def on_message(self, message: Message) -> None:
+ """Checks if the message has a snippet link, removes the embed, then sends the snippet contents."""
+ gh_match = GITHUB_RE.search(message.content)
+ gh_gist_match = GITHUB_GIST_RE.search(message.content)
+ gl_match = GITLAB_RE.search(message.content)
+ bb_match = BITBUCKET_RE.search(message.content)
+
+ if (gh_match or gh_gist_match or gl_match or bb_match) and not message.author.bot:
+ message_to_send = ''
+
+ for gh in GITHUB_RE.finditer(message.content):
+ message_to_send += await fetch_github_snippet(self.bot.http_session, **gh.groupdict())
+
+ for gh_gist in GITHUB_GIST_RE.finditer(message.content):
+ message_to_send += await fetch_github_gist_snippet(self.bot.http_session, **gh_gist.groupdict())
+
+ for gl in GITLAB_RE.finditer(message.content):
+ message_to_send += await fetch_gitlab_snippet(self.bot.http_session, **gl.groupdict())
+
+ for bb in BITBUCKET_RE.finditer(message.content):
+ message_to_send += await fetch_bitbucket_snippet(self.bot.http_session, **bb.groupdict())
+
+ if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15:
+ await message.edit(suppress=True)
+ await wait_for_deletion(
+ await message.channel.send(message_to_send),
+ (message.author.id,),
+ client=self.bot
+ )
+
+
+def setup(bot: Bot) -> None:
+ """Load the CodeSnippets cog."""
+ bot.add_cog(CodeSnippets(bot))
diff --git a/bot/cogs/print_snippets.py b/bot/cogs/print_snippets.py
deleted file mode 100644
index 3f784d2c6..000000000
--- a/bot/cogs/print_snippets.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import asyncio
-import os
-import re
-import textwrap
-
-import aiohttp
-from discord import Message, Reaction, User
-from discord.ext.commands import Cog
-
-from bot.bot import Bot
-
-
-async def fetch_http(session: aiohttp.ClientSession, url: str, response_format: str, **kwargs) -> str:
- """Uses aiohttp to make http GET requests."""
- async with session.get(url, **kwargs) as response:
- if response_format == 'text':
- return await response.text()
- elif response_format == 'json':
- return await response.json()
-
-
-async def revert_to_orig(d: dict) -> dict:
- """Replace URL Encoded values back to their original."""
- for obj in d:
- if d[obj] is not None:
- d[obj] = d[obj].replace('%2F', '/').replace('%2E', '.')
-
-
-async def orig_to_encode(d: dict) -> dict:
- """Encode URL Parameters."""
- for obj in d:
- if d[obj] is not None:
- d[obj] = d[obj].replace('/', '%2F').replace('.', '%2E')
-
-
-async def snippet_to_embed(d: dict, file_contents: str) -> str:
- """Given a regex groupdict and file contents, creates a code block."""
- if d['end_line']:
- start_line = int(d['start_line'])
- end_line = int(d['end_line'])
- else:
- start_line = end_line = int(d['start_line'])
-
- split_file_contents = file_contents.split('\n')
-
- if start_line > end_line:
- start_line, end_line = end_line, start_line
- if start_line > len(split_file_contents) or end_line < 1:
- return ''
- start_line = max(1, start_line)
- end_line = min(len(split_file_contents), end_line)
-
- required = '\n'.join(split_file_contents[start_line - 1:end_line])
- required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
-
- language = d['file_path'].split('/')[-1].split('.')[-1]
- if not language.replace('-', '').replace('+', '').replace('_', '').isalnum():
- language = ''
-
- if len(required) != 0:
- return f'```{language}\n{required}```\n'
- return '``` ```\n'
-
-
-GITHUB_RE = re.compile(
- r'https://github\.com/(?P<repo>.+?)/blob/(?P<branch>.+?)/'
- + r'(?P<file_path>.+?)#L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b'
-)
-
-GITHUB_GIST_RE = re.compile(
- r'https://gist\.github\.com/([^/]*)/(?P<gist_id>[0-9a-zA-Z]+)/*'
- + r'(?P<revision>[0-9a-zA-Z]*)/*#file-(?P<file_path>.+?)'
- + r'-L(?P<start_line>\d+)([-~]L(?P<end_line>\d+))?\b'
-)
-
-GITLAB_RE = re.compile(
- r'https://gitlab\.com/(?P<repo>.+?)/\-/blob/(?P<branch>.+?)/'
- + r'(?P<file_path>.+?)#L(?P<start_line>\d+)([-~](?P<end_line>\d+))?\b'
-)
-
-BITBUCKET_RE = re.compile(
- r'https://bitbucket\.org/(?P<repo>.+?)/src/(?P<branch>.+?)/'
- + r'(?P<file_path>.+?)#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?\b'
-)
-
-
-class PrintSnippets(Cog):
- """
- Cog that prints out snippets to Discord.
-
- Matches each message against a regex and prints the contents of all matched snippets.
- """
-
- def __init__(self, bot: Bot):
- """Initializes the cog's bot."""
- self.bot = bot
- self.session = aiohttp.ClientSession()
-
- @Cog.listener()
- async def on_message(self, message: Message) -> None:
- """Checks if the message has a snippet link, removes the embed, then sends the snippet contents."""
- gh_match = GITHUB_RE.search(message.content)
- gh_gist_match = GITHUB_GIST_RE.search(message.content)
- gl_match = GITLAB_RE.search(message.content)
- bb_match = BITBUCKET_RE.search(message.content)
-
- if (gh_match or gh_gist_match or gl_match or bb_match) and not message.author.bot:
- message_to_send = ''
-
- for gh in GITHUB_RE.finditer(message.content):
- d = gh.groupdict()
- headers = {'Accept': 'application/vnd.github.v3.raw'}
- if 'GITHUB_TOKEN' in os.environ:
- headers['Authorization'] = f'token {os.environ["GITHUB_TOKEN"]}'
- file_contents = await fetch_http(
- self.session,
- f'https://api.github.com/repos/{d["repo"]}'
- + f'/contents/{d["file_path"]}?ref={d["branch"]}',
- 'text',
- headers=headers,
- )
- message_to_send += await snippet_to_embed(d, file_contents)
-
- for gh_gist in GITHUB_GIST_RE.finditer(message.content):
- d = gh_gist.groupdict()
- gist_json = await fetch_http(
- self.session,
- f'https://api.github.com/gists/{d["gist_id"]}'
- + f'{"/" + d["revision"] if len(d["revision"]) > 0 else ""}',
- 'json',
- )
- for f in gist_json['files']:
- if d['file_path'] == f.lower().replace('.', '-'):
- d['file_path'] = f
- file_contents = await fetch_http(
- self.session,
- gist_json['files'][f]['raw_url'],
- 'text',
- )
- message_to_send += await snippet_to_embed(d, file_contents)
- break
-
- for gl in GITLAB_RE.finditer(message.content):
- d = gl.groupdict()
- await orig_to_encode(d)
- headers = {}
- if 'GITLAB_TOKEN' in os.environ:
- headers['PRIVATE-TOKEN'] = os.environ["GITLAB_TOKEN"]
- file_contents = await fetch_http(
- self.session,
- f'https://gitlab.com/api/v4/projects/{d["repo"]}/'
- + f'repository/files/{d["file_path"]}/raw?ref={d["branch"]}',
- 'text',
- headers=headers,
- )
- await revert_to_orig(d)
- message_to_send += await snippet_to_embed(d, file_contents)
-
- for bb in BITBUCKET_RE.finditer(message.content):
- d = bb.groupdict()
- await orig_to_encode(d)
- file_contents = await fetch_http(
- self.session,
- f'https://bitbucket.org/{d["repo"]}/raw/{d["branch"]}/{d["file_path"]}',
- 'text',
- )
- await revert_to_orig(d)
- message_to_send += await snippet_to_embed(d, file_contents)
-
- message_to_send = message_to_send[:-1]
-
- if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 50:
- sent_message = await message.channel.send(message_to_send)
- await message.edit(suppress=True)
- await sent_message.add_reaction('❌')
-
- def check(reaction: Reaction, user: User) -> bool:
- return user == message.author and str(reaction.emoji) == '❌'
-
- try:
- reaction, user = await self.bot.wait_for('reaction_add', timeout=10.0, check=check)
- except asyncio.TimeoutError:
- await sent_message.remove_reaction('❌', self.bot.user)
- else:
- await sent_message.delete()
-
-
-def setup(bot: Bot) -> None:
- """Load the Utils cog."""
- bot.add_cog(PrintSnippets(bot))