aboutsummaryrefslogtreecommitdiffstats
path: root/bot
diff options
context:
space:
mode:
Diffstat (limited to 'bot')
-rw-r--r--bot/exts/fun/latex/__init__.py7
-rw-r--r--bot/exts/fun/latex/_latex_cog.py60
-rw-r--r--bot/exts/fun/latex/_renderer.py45
3 files changed, 112 insertions, 0 deletions
diff --git a/bot/exts/fun/latex/__init__.py b/bot/exts/fun/latex/__init__.py
new file mode 100644
index 00000000..e58e0447
--- /dev/null
+++ b/bot/exts/fun/latex/__init__.py
@@ -0,0 +1,7 @@
+from bot.bot import Bot
+from bot.exts.fun.latex._latex_cog import Latex
+
+
+def setup(bot: Bot) -> None:
+ """Load the Latex Cog."""
+ bot.add_cog(Latex(bot))
diff --git a/bot/exts/fun/latex/_latex_cog.py b/bot/exts/fun/latex/_latex_cog.py
new file mode 100644
index 00000000..239f499c
--- /dev/null
+++ b/bot/exts/fun/latex/_latex_cog.py
@@ -0,0 +1,60 @@
+import asyncio
+import hashlib
+import sys
+from pathlib import Path
+import re
+
+import discord
+from discord.ext import commands
+
+
+FORMATTED_CODE_REGEX = re.compile(
+ r"(?P<delim>(?P<block>```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block
+ r"(?(block)(?:(?P<lang>[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline)
+ r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
+ r"(?P<code>.*?)" # extract all code inside the markup
+ r"\s*" # any more whitespace before the end of the code markup
+ r"(?P=delim)", # match the exact same delimiter from the start again
+ re.DOTALL | re.IGNORECASE, # "." also matches newlines, case insensitive
+)
+
+THIS_DIR = Path(__file__).parent
+CACHE_DIRECTORY = THIS_DIR / "cache"
+CACHE_DIRECTORY.mkdir(exist_ok=True)
+
+
+def _prepare_input(text: str) -> str:
+ text = text.replace(r"\\", "$\n$") # matplotlib uses \n for newlines, not \\
+
+ if match := FORMATTED_CODE_REGEX.match(text):
+ return match.group("code")
+ else:
+ return text
+
+
+class Latex(commands.Cog):
+ """Renders latex."""
+ @commands.command()
+ @commands.max_concurrency(1, commands.BucketType.guild, wait=True)
+ async def latex(self, ctx: commands.Context, *, query: str) -> None:
+ """Renders the text in latex and sends the image."""
+ query = _prepare_input(query)
+ query_hash = hashlib.md5(query.encode()).hexdigest()
+ image_path = CACHE_DIRECTORY / f"{query_hash}.png"
+ async with ctx.typing():
+ if not image_path.exists():
+ proc = await asyncio.subprocess.create_subprocess_exec(
+ sys.executable,
+ "_renderer.py",
+ query,
+ image_path.relative_to(THIS_DIR),
+ cwd=THIS_DIR,
+ stderr=asyncio.subprocess.PIPE
+ )
+ return_code = await proc.wait()
+ if return_code != 0:
+ image_path.unlink()
+ err = (await proc.stderr.read()).decode()
+ raise commands.BadArgument(err)
+
+ await ctx.send(file=discord.File(image_path, "latex.png"))
diff --git a/bot/exts/fun/latex/_renderer.py b/bot/exts/fun/latex/_renderer.py
new file mode 100644
index 00000000..3f6528ad
--- /dev/null
+++ b/bot/exts/fun/latex/_renderer.py
@@ -0,0 +1,45 @@
+import sys
+
+from pathlib import Path
+from typing import BinaryIO
+
+import matplotlib.pyplot as plt
+
+# configure fonts and colors for matplotlib
+plt.rcParams.update(
+ {
+ "font.size": 16,
+ "mathtext.fontset": "cm", # Computer Modern font set
+ "mathtext.rm": "serif",
+ "figure.facecolor": "36393F", # matches Discord's dark mode background color
+ "text.color": "white",
+ }
+)
+
+
+def render(text: str, file_handle: BinaryIO) -> None:
+ """
+ Saves rendered image in `file_handle`. In case the input is invalid latex, it prints the error to `stderr`.
+ """
+ fig = plt.figure()
+ fig.text(0, 1, text, horizontalalignment="left", verticalalignment="top")
+ try:
+ plt.savefig(file_handle, bbox_inches="tight", dpi=600)
+ except ValueError as err:
+ # get rid of traceback, keeping just the latex error
+ sys.exit(err)
+
+
+def main():
+ """
+ Renders a latex query and saves the output in a specified file.
+ Expects two command line arguments: the query and the path to the output file.
+ """
+ query = sys.argv[1]
+ out_file_path = Path(sys.argv[2])
+ with open(out_file_path, "wb") as out_file:
+ render(query, out_file)
+
+
+if __name__ == "__main__":
+ main()