diff options
Diffstat (limited to 'bot/utils')
-rw-r--r-- | bot/utils/__init__.py | 74 | ||||
-rw-r--r-- | bot/utils/halloween/spookifications.py | 8 | ||||
-rw-r--r-- | bot/utils/persist.py | 66 |
3 files changed, 132 insertions, 16 deletions
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index ef18a1b9..0aa50af6 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,4 +1,6 @@ import asyncio +import re +import string from typing import List import discord @@ -9,21 +11,15 @@ from bot.pagination import LinePaginator async def disambiguate( ctx: Context, entries: List[str], *, timeout: float = 30, - per_page: int = 20, empty: bool = False, embed: discord.Embed = None -): + entries_per_page: int = 20, empty: bool = False, embed: discord.Embed = None +) -> str: """ Has the user choose between multiple entries in case one could not be chosen automatically. + Disambiguation will be canceled after `timeout` seconds. + This will raise a BadArgument if entries is empty, if the disambiguation event times out, or if the user makes an invalid choice. - - :param ctx: Context object from discord.py - :param entries: List of items for user to choose from - :param timeout: Number of seconds to wait before canceling disambiguation - :param per_page: Entries per embed page - :param empty: Whether the paginator should have an extra line between items - :param embed: The embed that the paginator will use. - :return: Users choice for correct entry. """ if len(entries) == 0: raise BadArgument('No matches found.') @@ -33,7 +29,7 @@ async def disambiguate( choices = (f'{index}: {entry}' for index, entry in enumerate(entries, start=1)) - def check(message): + def check(message: discord.Message) -> bool: return (message.content.isdigit() and message.author == ctx.author and message.channel == ctx.channel) @@ -43,7 +39,7 @@ async def disambiguate( embed = discord.Embed() coro1 = ctx.bot.wait_for('message', check=check, timeout=timeout) - coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=per_page, + coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=entries_per_page, empty=empty, max_size=6000, timeout=9000) # wait_for timeout will go to except instead of the wait_for thing as I expected @@ -77,3 +73,57 @@ async def disambiguate( return entries[index - 1] except IndexError: raise BadArgument('Invalid choice.') + + +def replace_many( + sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False +) -> str: + """ + Replaces multiple substrings in a string given a mapping of strings. + + By default replaces long strings before short strings, and lowercase before uppercase. + Example: + var = replace_many("This is a sentence", {"is": "was", "This": "That"}) + assert var == "That was a sentence" + + If `ignore_case` is given, does a case insensitive match. + Example: + var = replace_many("THIS is a sentence", {"IS": "was", "tHiS": "That"}, ignore_case=True) + assert var == "That was a sentence" + + If `match_case` is given, matches the case of the replacement with the replaced word. + Example: + var = replace_many( + "This IS a sentence", {"is": "was", "this": "that"}, ignore_case=True, match_case=True + ) + assert var == "That WAS a sentence" + """ + if ignore_case: + replacements = dict( + (word.lower(), replacement) for word, replacement in replacements.items() + ) + + words_to_replace = sorted(replacements, key=lambda s: (-len(s), s)) + + # Join and compile words to replace into a regex + pattern = "|".join(re.escape(word) for word in words_to_replace) + regex = re.compile(pattern, re.I if ignore_case else 0) + + def _repl(match: re.Match) -> str: + """Returns replacement depending on `ignore_case` and `match_case`.""" + word = match.group(0) + replacement = replacements[word.lower() if ignore_case else word] + + if not match_case: + return replacement + + # Clean punctuation from word so string methods work + cleaned_word = word.translate(str.maketrans('', '', string.punctuation)) + if cleaned_word.isupper(): + return replacement.upper() + elif cleaned_word[0].isupper(): + return replacement.capitalize() + else: + return replacement.lower() + + return regex.sub(_repl, sentence) diff --git a/bot/utils/halloween/spookifications.py b/bot/utils/halloween/spookifications.py index 69b49919..11f69850 100644 --- a/bot/utils/halloween/spookifications.py +++ b/bot/utils/halloween/spookifications.py @@ -7,7 +7,7 @@ from PIL import ImageOps log = logging.getLogger() -def inversion(im): +def inversion(im: Image) -> Image: """ Inverts the image. @@ -18,7 +18,7 @@ def inversion(im): return inv -def pentagram(im): +def pentagram(im: Image) -> Image: """Adds pentagram to the image.""" im = im.convert('RGB') wt, ht = im.size @@ -28,7 +28,7 @@ def pentagram(im): return im -def bat(im): +def bat(im: Image) -> Image: """ Adds a bat silhoutte to the image. @@ -50,7 +50,7 @@ def bat(im): return im -def get_random_effect(im): +def get_random_effect(im: Image) -> Image: """Randomly selects and applies an effect.""" effects = [inversion, pentagram, bat] effect = choice(effects) diff --git a/bot/utils/persist.py b/bot/utils/persist.py new file mode 100644 index 00000000..a60a1219 --- /dev/null +++ b/bot/utils/persist.py @@ -0,0 +1,66 @@ +import sqlite3 +from pathlib import Path +from shutil import copyfile + +from bot.seasons.season import get_seasons + +DIRECTORY = Path("data") # directory that has a persistent volume mapped to it + + +def make_persistent(file_path: Path) -> Path: + """ + Copy datafile at the provided file_path to the persistent data directory. + + A persistent data file is needed by some features in order to not lose data + after bot rebuilds. + + This function will ensure that a clean data file with default schema, + structure or data is copied over to the persistent volume before returning + the path to this new persistent version of the file. + + If the persistent file already exists, it won't be overwritten with the + clean default file, just returning the Path instead to the existing file. + + Note: Avoid using the same file name as other features in the same seasons + as otherwise only one datafile can be persistent and will be returned for + both cases. + + Example Usage: + >>> import json + >>> template_datafile = Path("bot", "resources", "evergreen", "myfile.json") + >>> path_to_persistent_file = make_persistent(template_datafile) + >>> print(path_to_persistent_file) + data/evergreen/myfile.json + >>> with path_to_persistent_file.open("w+") as f: + >>> data = json.load(f) + """ + # ensure the persistent data directory exists + DIRECTORY.mkdir(exist_ok=True) + + if not file_path.is_file(): + raise OSError(f"File not found at {file_path}.") + + # detect season in datafile path for assigning to subdirectory + season = next((s for s in get_seasons() if s in file_path.parts), None) + + if season: + # make sure subdirectory exists first + subdirectory = Path(DIRECTORY, season) + subdirectory.mkdir(exist_ok=True) + + persistent_path = Path(subdirectory, file_path.name) + + else: + persistent_path = Path(DIRECTORY, file_path.name) + + # copy base/template datafile to persistent directory + if not persistent_path.exists(): + copyfile(file_path, persistent_path) + + return persistent_path + + +def sqlite(db_path: Path) -> sqlite3.Connection: + """Copy sqlite file to the persistent data directory and return an open connection.""" + persistent_path = make_persistent(db_path) + return sqlite3.connect(persistent_path) |