aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils/__init__.py
blob: 72a681a378875eff984eb9cfece57675498ded65 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import asyncio
import re
import string
from typing import List

import discord
from discord.ext.commands import BadArgument, Context

from bot.pagination import LinePaginator


async def disambiguate(
        ctx: Context, entries: List[str], *, timeout: float = 30,
        entries_per_page: int = 20, empty: bool = False, embed: discord.Embed = None
) -> str:
    """
    Has the user choose between multiple entries in case one could not be chosen automatically.

    The entries are presented as a numbered list (starting at 1) through `LinePaginator`. The
    first message made only of decimal digits that is sent by the invoking author in the
    invoking channel is taken as the choice.

    Disambiguation will be canceled after `timeout` seconds.

    `entries_per_page`, `empty` and `embed` are forwarded to `LinePaginator.paginate`; a fresh
    `discord.Embed` is used when `embed` is None.

    Returns the chosen entry. A single entry is returned immediately without prompting.

    This will raise a BadArgument if entries is empty, if the disambiguation event times out,
    if the pagination is canceled, or if the user makes an invalid choice.
    """
    if len(entries) == 0:
        raise BadArgument('No matches found.')

    if len(entries) == 1:
        return entries[0]

    choices = (f'{index}: {entry}' for index, entry in enumerate(entries, start=1))

    def check(message):
        # isdecimal() rather than isdigit(): characters such as '²' pass isdigit()
        # but are rejected by int(), which would leak a ValueError further down.
        return (message.content.isdecimal()
                and message.author == ctx.author
                and message.channel == ctx.channel)

    if embed is None:
        embed = discord.Embed()

    coro1 = ctx.bot.wait_for('message', check=check, timeout=timeout)
    coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=entries_per_page,
                                   empty=empty, max_size=6000, timeout=9000)

    futures = [asyncio.ensure_future(coro1), asyncio.ensure_future(coro2)]

    try:
        # No explicit `loop=`: the argument was removed from asyncio.wait in Python 3.10,
        # and the running loop is used by default.
        done, pending = await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)

        # A wait_for timeout surfaces here as asyncio.TimeoutError when the finished
        # future's result is read.
        result = next(iter(done)).result()

        # Pagination was canceled - result is None
        if result is None:
            raise BadArgument('Canceled.')

        # Pagination was not initiated, only one page
        if result.author == ctx.bot.user:
            # Continue the wait_for
            result = await next(iter(pending))
    except asyncio.TimeoutError:
        raise BadArgument('Timed out.')
    finally:
        # Always stop whatever is still running (including on timeout), so the paginator
        # is not left alive after this function exits. Cancelling a finished future is a no-op.
        for future in futures:
            future.cancel()

    # Safe to convert: check() only accepts decimal-digit content.
    index = int(result.content)

    # Explicit range check: 0 would otherwise wrap around to entries[-1].
    if not 1 <= index <= len(entries):
        raise BadArgument('Invalid choice.')

    return entries[index - 1]


def replace_many(
        sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False
) -> str:
    """
    Replaces multiple substrings in a string given a mapping of strings.

    All replacements happen in a single pass over `sentence`, so text produced by one
    replacement is never replaced again. Longer strings are tried before shorter ones, so a
    key that contains another key takes precedence at the same position.
    Example:
        var = replace_many("This is a sentence", {"is": "was", "This": "That"})
        assert var == "That was a sentence"

    If `ignore_case` is given, does a case insensitive match.
    Example:
        var = replace_many("THIS is a sentence", {"IS": "was", "tHiS": "That"}, ignore_case=True)
        assert var == "That was a sentence"

    If `match_case` is given, matches the case of the replacement with the replaced word:
    fully uppercase words give an uppercase replacement, words starting with an uppercase
    letter give a capitalized replacement, and anything else gives a lowercase replacement.
    Punctuation in the replaced word is ignored when determining its case.
    Example:
        var = replace_many(
            "This IS a sentence", {"is": "was", "this": "that"}, ignore_case=True, match_case=True
        )
        assert var == "That WAS a sentence"

    If `replacements` is empty, `sentence` is returned unchanged.
    """
    # An empty mapping would compile to the empty pattern, which matches at every
    # position and then fails the lookup in `_repl`.
    if not replacements:
        return sentence

    if ignore_case:
        replacements = {word.lower(): replacement for word, replacement in replacements.items()}

    # Longest first so that the regex alternation prefers the longest key at a position.
    words_to_replace = sorted(replacements, key=lambda s: (-len(s), s))

    # Join and compile words to replace into a regex
    pattern = "|".join(re.escape(word) for word in words_to_replace)
    regex = re.compile(pattern, re.I if ignore_case else 0)

    # Built once rather than on every match.
    strip_punctuation = str.maketrans('', '', string.punctuation)

    def _repl(match):
        """Returns replacement depending on `ignore_case` and `match_case`."""
        word = match.group(0)
        replacement = replacements[word.lower() if ignore_case else word]

        if not match_case:
            return replacement

        # Clean punctuation from word so string methods work
        cleaned_word = word.translate(strip_punctuation)
        if cleaned_word.isupper():
            return replacement.upper()
        # Slice instead of index: the cleaned word is empty when the match is pure punctuation.
        if cleaned_word[:1].isupper():
            return replacement.capitalize()
        return replacement.lower()

    return regex.sub(_repl, sentence)