aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/__init__.py74
-rw-r--r--bot/utils/halloween/spookifications.py8
-rw-r--r--bot/utils/persist.py66
3 files changed, 132 insertions, 16 deletions
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index ef18a1b9..0aa50af6 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -1,4 +1,6 @@
import asyncio
+import re
+import string
from typing import List
import discord
@@ -9,21 +11,15 @@ from bot.pagination import LinePaginator
async def disambiguate(
ctx: Context, entries: List[str], *, timeout: float = 30,
- per_page: int = 20, empty: bool = False, embed: discord.Embed = None
-):
+ entries_per_page: int = 20, empty: bool = False, embed: discord.Embed = None
+) -> str:
"""
Has the user choose between multiple entries in case one could not be chosen automatically.
+ Disambiguation will be canceled after `timeout` seconds.
+
This will raise a BadArgument if entries is empty, if the disambiguation event times out,
or if the user makes an invalid choice.
-
- :param ctx: Context object from discord.py
- :param entries: List of items for user to choose from
- :param timeout: Number of seconds to wait before canceling disambiguation
- :param per_page: Entries per embed page
- :param empty: Whether the paginator should have an extra line between items
- :param embed: The embed that the paginator will use.
- :return: Users choice for correct entry.
"""
if len(entries) == 0:
raise BadArgument('No matches found.')
@@ -33,7 +29,7 @@ async def disambiguate(
choices = (f'{index}: {entry}' for index, entry in enumerate(entries, start=1))
- def check(message):
+ def check(message: discord.Message) -> bool:
return (message.content.isdigit()
and message.author == ctx.author
and message.channel == ctx.channel)
@@ -43,7 +39,7 @@ async def disambiguate(
embed = discord.Embed()
coro1 = ctx.bot.wait_for('message', check=check, timeout=timeout)
- coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=per_page,
+ coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=entries_per_page,
empty=empty, max_size=6000, timeout=9000)
# wait_for timeout will go to except instead of the wait_for thing as I expected
@@ -77,3 +73,57 @@ async def disambiguate(
return entries[index - 1]
except IndexError:
raise BadArgument('Invalid choice.')
+
+
+def replace_many(
+ sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False
+) -> str:
+ """
+ Replaces multiple substrings in a string given a mapping of strings.
+
+ By default replaces long strings before short strings, and lowercase before uppercase.
+ Example:
+ var = replace_many("This is a sentence", {"is": "was", "This": "That"})
+ assert var == "That was a sentence"
+
+ If `ignore_case` is given, does a case insensitive match.
+ Example:
+ var = replace_many("THIS is a sentence", {"IS": "was", "tHiS": "That"}, ignore_case=True)
+ assert var == "That was a sentence"
+
+ If `match_case` is given, matches the case of the replacement with the replaced word.
+ Example:
+ var = replace_many(
+ "This IS a sentence", {"is": "was", "this": "that"}, ignore_case=True, match_case=True
+ )
+ assert var == "That WAS a sentence"
+ """
+ if ignore_case:
+ replacements = dict(
+ (word.lower(), replacement) for word, replacement in replacements.items()
+ )
+
+ words_to_replace = sorted(replacements, key=lambda s: (-len(s), s))
+
+ # Join and compile words to replace into a regex
+ pattern = "|".join(re.escape(word) for word in words_to_replace)
+ regex = re.compile(pattern, re.I if ignore_case else 0)
+
+ def _repl(match: re.Match) -> str:
+ """Returns replacement depending on `ignore_case` and `match_case`."""
+ word = match.group(0)
+ replacement = replacements[word.lower() if ignore_case else word]
+
+ if not match_case:
+ return replacement
+
+ # Clean punctuation from word so string methods work
+ cleaned_word = word.translate(str.maketrans('', '', string.punctuation))
+ if cleaned_word.isupper():
+ return replacement.upper()
+ elif cleaned_word[0].isupper():
+ return replacement.capitalize()
+ else:
+ return replacement.lower()
+
+ return regex.sub(_repl, sentence)
diff --git a/bot/utils/halloween/spookifications.py b/bot/utils/halloween/spookifications.py
index 69b49919..11f69850 100644
--- a/bot/utils/halloween/spookifications.py
+++ b/bot/utils/halloween/spookifications.py
@@ -7,7 +7,7 @@ from PIL import ImageOps
log = logging.getLogger()
-def inversion(im):
+def inversion(im: Image) -> Image:
"""
Inverts the image.
@@ -18,7 +18,7 @@ def inversion(im):
return inv
-def pentagram(im):
+def pentagram(im: Image) -> Image:
"""Adds pentagram to the image."""
im = im.convert('RGB')
wt, ht = im.size
@@ -28,7 +28,7 @@ def pentagram(im):
return im
-def bat(im):
+def bat(im: Image) -> Image:
"""
Adds a bat silhoutte to the image.
@@ -50,7 +50,7 @@ def bat(im):
return im
-def get_random_effect(im):
+def get_random_effect(im: Image) -> Image:
"""Randomly selects and applies an effect."""
effects = [inversion, pentagram, bat]
effect = choice(effects)
diff --git a/bot/utils/persist.py b/bot/utils/persist.py
new file mode 100644
index 00000000..a60a1219
--- /dev/null
+++ b/bot/utils/persist.py
@@ -0,0 +1,66 @@
+import sqlite3
+from pathlib import Path
+from shutil import copyfile
+
+from bot.seasons.season import get_seasons
+
+DIRECTORY = Path("data") # directory that has a persistent volume mapped to it
+
+
+def make_persistent(file_path: Path) -> Path:
+ """
+ Copy datafile at the provided file_path to the persistent data directory.
+
+ A persistent data file is needed by some features in order to not lose data
+ after bot rebuilds.
+
+ This function will ensure that a clean data file with default schema,
+ structure or data is copied over to the persistent volume before returning
+ the path to this new persistent version of the file.
+
+ If the persistent file already exists, it won't be overwritten with the
+ clean default file, just returning the Path instead to the existing file.
+
+ Note: Avoid using the same file name as other features in the same seasons
+ as otherwise only one datafile can be persistent and will be returned for
+ both cases.
+
+ Example Usage:
+ >>> import json
+ >>> template_datafile = Path("bot", "resources", "evergreen", "myfile.json")
+ >>> path_to_persistent_file = make_persistent(template_datafile)
+ >>> print(path_to_persistent_file)
+ data/evergreen/myfile.json
+ >>> with path_to_persistent_file.open("w+") as f:
+ >>> data = json.load(f)
+ """
+ # ensure the persistent data directory exists
+ DIRECTORY.mkdir(exist_ok=True)
+
+ if not file_path.is_file():
+ raise OSError(f"File not found at {file_path}.")
+
+ # detect season in datafile path for assigning to subdirectory
+ season = next((s for s in get_seasons() if s in file_path.parts), None)
+
+ if season:
+ # make sure subdirectory exists first
+ subdirectory = Path(DIRECTORY, season)
+ subdirectory.mkdir(exist_ok=True)
+
+ persistent_path = Path(subdirectory, file_path.name)
+
+ else:
+ persistent_path = Path(DIRECTORY, file_path.name)
+
+ # copy base/template datafile to persistent directory
+ if not persistent_path.exists():
+ copyfile(file_path, persistent_path)
+
+ return persistent_path
+
+
+def sqlite(db_path: Path) -> sqlite3.Connection:
+ """Copy sqlite file to the persistent data directory and return an open connection."""
+ persistent_path = make_persistent(db_path)
+ return sqlite3.connect(persistent_path)