diff options
Diffstat (limited to 'bot/utils')
| -rw-r--r-- | bot/utils/__init__.py | 43 | ||||
| -rw-r--r-- | bot/utils/checks.py | 30 | ||||
| -rw-r--r-- | bot/utils/converters.py | 105 | ||||
| -rw-r--r-- | bot/utils/decorators.py | 12 | ||||
| -rw-r--r-- | bot/utils/exceptions.py | 2 | ||||
| -rw-r--r-- | bot/utils/extensions.py | 10 | ||||
| -rw-r--r-- | bot/utils/halloween/spookifications.py | 10 | ||||
| -rw-r--r-- | bot/utils/helpers.py | 8 | ||||
| -rw-r--r-- | bot/utils/messages.py | 19 | ||||
| -rw-r--r-- | bot/utils/pagination.py | 26 | ||||
| -rw-r--r-- | bot/utils/time.py | 2 | 
11 files changed, 215 insertions, 52 deletions
| diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 35ef0a7b..bef12d25 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -3,7 +3,7 @@ import contextlib  import re  import string  from datetime import datetime -from typing import Iterable, List +from typing import Iterable, List, Optional  import discord  from discord.ext.commands import BadArgument, Context @@ -31,8 +31,13 @@ def resolve_current_month() -> Month:  async def disambiguate( -        ctx: Context, entries: List[str], *, timeout: float = 30, -        entries_per_page: int = 20, empty: bool = False, embed: discord.Embed = None +    ctx: Context, +    entries: List[str], +    *, +    timeout: float = 30, +    entries_per_page: int = 20, +    empty: bool = False, +    embed: Optional[discord.Embed] = None  ) -> str:      """      Has the user choose between multiple entries in case one could not be chosen automatically. @@ -43,25 +48,29 @@ async def disambiguate(      or if the user makes an invalid choice.      """      if len(entries) == 0: -        raise BadArgument('No matches found.') +        raise BadArgument("No matches found.")      if len(entries) == 1:          return entries[0] -    choices = (f'{index}: {entry}' for index, entry in enumerate(entries, start=1)) +    choices = (f"{index}: {entry}" for index, entry in enumerate(entries, start=1))      def check(message: discord.Message) -> bool: -        return (message.content.isdigit() -                and message.author == ctx.author -                and message.channel == ctx.channel) +        return ( +            message.content.isdecimal() +            and message.author == ctx.author +            and message.channel == ctx.channel +        )      try:          if embed is None:              embed = discord.Embed() -        coro1 = ctx.bot.wait_for('message', check=check, timeout=timeout) -        coro2 = LinePaginator.paginate(choices, ctx, embed=embed, max_lines=entries_per_page, -                                       empty=empty, max_size=6000, timeout=9000) +        coro1 = ctx.bot.wait_for("message", check=check, timeout=timeout) +        coro2 = LinePaginator.paginate( +            choices, ctx, embed=embed, max_lines=entries_per_page, +            empty=empty, max_size=6000, timeout=9000 +        )          # wait_for timeout will go to except instead of the wait_for thing as I expected          futures = [asyncio.ensure_future(coro1), asyncio.ensure_future(coro2)] @@ -74,7 +83,7 @@ async def disambiguate(          if result is None:              for coro in pending:                  coro.cancel() -            raise BadArgument('Canceled.') +            raise BadArgument("Canceled.")          # Pagination was not initiated, only one page          if result.author == ctx.bot.user: @@ -85,19 +94,19 @@ async def disambiguate(          for coro in pending:              coro.cancel()      except asyncio.TimeoutError: -        raise BadArgument('Timed out.') +        raise BadArgument("Timed out.") -    # Guaranteed to not error because of isdigit() in check +    # Guaranteed to not error because of isdecimal() in check      index = int(result.content)      try:          return entries[index - 1]      except IndexError: -        raise BadArgument('Invalid choice.') +        raise BadArgument("Invalid choice.")  def replace_many( -        sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False +    sentence: str, replacements: dict, *, ignore_case: bool = False, match_case: bool = False  ) -> str:      """      Replaces multiple substrings in a string given a mapping of strings. @@ -139,7 +148,7 @@ def replace_many(              return replacement          # Clean punctuation from word so string methods work -        cleaned_word = word.translate(str.maketrans('', '', string.punctuation)) +        cleaned_word = word.translate(str.maketrans("", "", string.punctuation))          if cleaned_word.isupper():              return replacement.upper()          elif cleaned_word[0].isupper(): diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 9dd4dde0..c06b6870 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -92,8 +92,10 @@ def in_whitelist_check(  def with_role_check(ctx: Context, *role_ids: int) -> bool:      """Returns True if the user has any one of the roles in role_ids."""      if not ctx.guild:  # Return False in a DM -        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " -                  "This command is restricted by the with_role decorator. Rejecting request.") +        log.trace( +            f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. " +            "This command is restricted by the with_role decorator. Rejecting request." +        )          return False      for role in ctx.author.roles: @@ -101,22 +103,28 @@ def with_role_check(ctx: Context, *role_ids: int) -> bool:              log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")              return True -    log.trace(f"{ctx.author} does not have the required role to use " -              f"the '{ctx.command.name}' command, so the request is rejected.") +    log.trace( +        f"{ctx.author} does not have the required role to use " +        f"the '{ctx.command.name}' command, so the request is rejected." +    )      return False  def without_role_check(ctx: Context, *role_ids: int) -> bool:      """Returns True if the user does not have any of the roles in role_ids."""      if not ctx.guild:  # Return False in a DM -        log.trace(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " -                  "This command is restricted by the without_role decorator. Rejecting request.") +        log.trace( +            f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " +            "This command is restricted by the without_role decorator. Rejecting request." +        )          return False      author_roles = [role.id for role in ctx.author.roles]      check = all(role not in author_roles for role in role_ids) -    log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. " -              f"The result of the without_role check was {check}.") +    log.trace( +        f"{ctx.author} tried to call the '{ctx.command.name}' command. " +        f"The result of the without_role check was {check}." +    )      return check @@ -154,8 +162,10 @@ def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketTy          #          # If the `before_invoke` detail is ever a problem then I can quickly just swap over.          if not isinstance(command, Command): -            raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. ' -                            'This means it has to be above the command decorator in the code.') +            raise TypeError( +                "Decorator `cooldown_with_role_bypass` must be applied after the command decorator. " +                "This means it has to be above the command decorator in the code." +            )          command._before_invoke = predicate diff --git a/bot/utils/converters.py b/bot/utils/converters.py index 228714c9..fe2c980c 100644 --- a/bot/utils/converters.py +++ b/bot/utils/converters.py @@ -1,11 +1,14 @@ +from datetime import datetime +from typing import Tuple, Union +  import discord -from discord.ext.commands.converter import MessageConverter +from discord.ext import commands -class WrappedMessageConverter(MessageConverter): +class WrappedMessageConverter(commands.MessageConverter):      """A converter that handles embed-suppressed links like <http://example.com>.""" -    async def convert(self, ctx: discord.ext.commands.Context, argument: str) -> discord.Message: +    async def convert(self, ctx: commands.Context, argument: str) -> discord.Message:          """Wrap the commands.MessageConverter to handle <> delimited message links."""          # It's possible to wrap a message in [<>] as well, and it's supported because its easy          if argument.startswith("[") and argument.endswith("]"): @@ -14,3 +17,99 @@ class WrappedMessageConverter(MessageConverter):              argument = argument[1:-1]          return await super().convert(ctx, argument) + + +class CoordinateConverter(commands.Converter): +    """Converter for Coordinates.""" + +    @staticmethod +    async def convert(ctx: commands.Context, coordinate: str) -> Tuple[int, int]: +        """Take in a coordinate string and turn it into an (x, y) tuple.""" +        if len(coordinate) not in (2, 3): +            raise commands.BadArgument("Invalid co-ordinate provided.") + +        coordinate = coordinate.lower() +        if coordinate[0].isalpha(): +            digit = coordinate[1:] +            letter = coordinate[0] +        else: +            digit = coordinate[:-1] +            letter = coordinate[-1] + +        if not digit.isdecimal(): +            raise commands.BadArgument + +        x = ord(letter) - ord("a") +        y = int(digit) - 1 + +        if (not 0 <= x <= 9) or (not 0 <= y <= 9): +            raise commands.BadArgument +        return x, y + + +SourceType = Union[commands.Command, commands.Cog] + + +class SourceConverter(commands.Converter): +    """Convert an argument into a command or cog.""" + +    @staticmethod +    async def convert(ctx: commands.Context, argument: str) -> SourceType: +        """Convert argument into source object.""" +        cog = ctx.bot.get_cog(argument) +        if cog: +            return cog + +        cmd = ctx.bot.get_command(argument) +        if cmd: +            return cmd + +        raise commands.BadArgument( +            f"Unable to convert `{argument}` to valid command or Cog." +        ) + + +class DateConverter(commands.Converter): +    """Parse SOL or earth date (in format YYYY-MM-DD) into `int` or `datetime`. When invalid input, raise error.""" + +    @staticmethod +    async def convert(ctx: commands.Context, argument: str) -> Union[int, datetime]: +        """Parse date (SOL or earth) into `datetime` or `int`. When invalid value, raise error.""" +        if argument.isdecimal(): +            return int(argument) +        try: +            date = datetime.strptime(argument, "%Y-%m-%d") +        except ValueError: +            raise commands.BadArgument( +                f"Can't convert `{argument}` to `datetime` in format `YYYY-MM-DD` or `int` in SOL." +            ) +        return date + + +class Subreddit(commands.Converter): +    """Forces a string to begin with "r/" and checks if it's a valid subreddit.""" + +    @staticmethod +    async def convert(ctx: commands.Context, sub: str) -> str: +        """ +        Force sub to begin with "r/" and check if it's a valid subreddit. + +        If sub is a valid subreddit, return it prepended with "r/" +        """ +        sub = sub.lower() + +        if not sub.startswith("r/"): +            sub = f"r/{sub}" + +        resp = await ctx.bot.http_session.get( +            "https://www.reddit.com/subreddits/search.json", +            params={"q": sub} +        ) + +        json = await resp.json() +        if not json["data"]["children"]: +            raise commands.BadArgument( +                f"The subreddit `{sub}` either doesn't exist, or it has no posts." +            ) + +        return sub diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index c12a15ff..c0783144 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -11,7 +11,7 @@ from discord import Colour, Embed  from discord.ext import commands  from discord.ext.commands import CheckFailure, Command, Context -from bot.constants import ERROR_REPLIES, Month +from bot.constants import Channels, ERROR_REPLIES, Month, WHITELISTED_CHANNELS  from bot.utils import human_months, resolve_current_month  from bot.utils.checks import in_whitelist_check @@ -253,6 +253,12 @@ def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context],          channels = set(kwargs.get("channels") or {})          categories = kwargs.get("categories") +        # Only output override channels + community_bot_commands +        if channels: +            default_whitelist_channels = set(WHITELISTED_CHANNELS) +            default_whitelist_channels.discard(Channels.community_bot_commands) +            channels.difference_update(default_whitelist_channels) +          # Add all whitelisted category channels          if categories:              for category_id in categories: @@ -260,10 +266,10 @@ def whitelist_check(**default_kwargs: t.Container[int]) -> t.Callable[[Context],                  if category is None:                      continue -                [channels.add(channel.id) for channel in category.text_channels] +                channels.update(channel.id for channel in category.text_channels)          if channels: -            channels_str = ', '.join(f"<#{c_id}>" for c_id in channels) +            channels_str = ", ".join(f"<#{c_id}>" for c_id in channels)              message = f"Sorry, but you may only use this command within {channels_str}."          else:              message = "Sorry, but you may not use this command." diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index 2b1c1b31..9e080759 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -1,4 +1,4 @@  class UserNotPlayingError(Exception): -    """Will raised when user try to use game commands when not playing.""" +    """Raised when users try to use game commands when they are not playing."""      pass diff --git a/bot/utils/extensions.py b/bot/utils/extensions.py index 50350ea8..cd491c4b 100644 --- a/bot/utils/extensions.py +++ b/bot/utils/extensions.py @@ -3,6 +3,8 @@ import inspect  import pkgutil  from typing import Iterator, NoReturn +from discord.ext.commands import Context +  from bot import exts @@ -31,4 +33,12 @@ def walk_extensions() -> Iterator[str]:          yield module.name +async def invoke_help_command(ctx: Context) -> None: +    """Invoke the help command or default help command if help extensions is not loaded.""" +    if "bot.exts.evergreen.help" in ctx.bot.extensions: +        help_command = ctx.bot.get_command("help") +        await ctx.invoke(help_command, ctx.command.qualified_name) +        return +    await ctx.send_help(ctx.command) +  EXTENSIONS = frozenset(walk_extensions()) diff --git a/bot/utils/halloween/spookifications.py b/bot/utils/halloween/spookifications.py index 11f69850..f69dd6fd 100644 --- a/bot/utils/halloween/spookifications.py +++ b/bot/utils/halloween/spookifications.py @@ -13,16 +13,16 @@ def inversion(im: Image) -> Image:      Returns an inverted image when supplied with an Image object.      """ -    im = im.convert('RGB') +    im = im.convert("RGB")      inv = ImageOps.invert(im)      return inv  def pentagram(im: Image) -> Image:      """Adds pentagram to the image.""" -    im = im.convert('RGB') +    im = im.convert("RGB")      wt, ht = im.size -    penta = Image.open('bot/resources/halloween/bloody-pentagram.png') +    penta = Image.open("bot/resources/halloween/bloody-pentagram.png")      penta = penta.resize((wt, ht))      im.paste(penta, (0, 0), penta)      return im @@ -35,9 +35,9 @@ def bat(im: Image) -> Image:      The bat silhoutte is of a size at least one-fifths that of the original image and may be rotated      up to 90 degrees anti-clockwise.      """ -    im = im.convert('RGB') +    im = im.convert("RGB")      wt, ht = im.size -    bat = Image.open('bot/resources/halloween/bat-clipart.png') +    bat = Image.open("bot/resources/halloween/bat-clipart.png")      bat_size = randint(wt//10, wt//7)      rot = randint(0, 90)      bat = bat.resize((bat_size, bat_size)) diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py new file mode 100644 index 00000000..74c2ccd0 --- /dev/null +++ b/bot/utils/helpers.py @@ -0,0 +1,8 @@ +import re + + +def suppress_links(message: str) -> str: +    """Accepts a message that may contain links, suppresses them, and returns them.""" +    for link in set(re.findall(r"https?://[^\s]+", message, re.IGNORECASE)): +        message = message.replace(link, f"<{link}>") +    return message diff --git a/bot/utils/messages.py b/bot/utils/messages.py new file mode 100644 index 00000000..a6c035f9 --- /dev/null +++ b/bot/utils/messages.py @@ -0,0 +1,19 @@ +import re +from typing import Optional + + +def sub_clyde(username: Optional[str]) -> Optional[str]: +    """ +    Replace "e"/"E" in any "clyde" in `username` with a Cyrillic "е"/"E" and return the new string. + +    Discord disallows "clyde" anywhere in the username for webhooks. It will return a 400. +    Return None only if `username` is None. +    """ +    def replace_e(match: re.Match) -> str: +        char = "е" if match[2] == "e" else "Е" +        return match[1] + char + +    if username: +        return re.sub(r"(clyd)(e)", replace_e, username, flags=re.I) +    else: +        return username  # Empty string or None diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py index a4d0cc56..742281d7 100644 --- a/bot/utils/pagination.py +++ b/bot/utils/pagination.py @@ -4,6 +4,7 @@ from typing import Iterable, List, Optional, Tuple  from discord import Embed, Member, Reaction  from discord.abc import User +from discord.embeds import EmptyEmbed  from discord.ext.commands import Context, Paginator  from bot.constants import Emojis @@ -26,7 +27,7 @@ class EmptyPaginatorEmbed(Exception):  class LinePaginator(Paginator):      """A class that aids in paginating code blocks for Discord messages.""" -    def __init__(self, prefix: str = '```', suffix: str = '```', max_size: int = 2000, max_lines: int = None): +    def __init__(self, prefix: str = "```", suffix: str = "```", max_size: int = 2000, max_lines: int = None):          """          Overrides the Paginator.__init__ from inside discord.ext.commands. @@ -44,7 +45,7 @@ class LinePaginator(Paginator):          self._count = len(prefix) + 1  # prefix + newline          self._pages = [] -    def add_line(self, line: str = '', *, empty: bool = False) -> None: +    def add_line(self, line: str = "", *, empty: bool = False) -> None:          """          Adds a line to the current page. @@ -56,7 +57,7 @@ class LinePaginator(Paginator):          If `empty` is True, an empty line will be placed after the a given `line`.          """          if len(line) > self.max_size - len(self.prefix) - 2: -            raise RuntimeError('Line exceeds maximum page size %s' % (self.max_size - len(self.prefix) - 2)) +            raise RuntimeError("Line exceeds maximum page size %s" % (self.max_size - len(self.prefix) - 2))          if self.max_lines is not None:              if self._linecount >= self.max_lines: @@ -71,7 +72,7 @@ class LinePaginator(Paginator):          self._current_page.append(line)          if empty: -            self._current_page.append('') +            self._current_page.append("")              self._count += 1      @classmethod @@ -79,7 +80,7 @@ class LinePaginator(Paginator):                         prefix: str = "", suffix: str = "", max_lines: Optional[int] = None,                         max_size: int = 500, empty: bool = True, restrict_to_user: User = None,                         timeout: int = 300, footer_text: str = None, url: str = None, -                       exception_on_empty_embed: bool = False): +                       exception_on_empty_embed: bool = False) -> None:          """          Use a paginator and set of reactions to provide pagination over a set of lines. @@ -157,7 +158,8 @@ class LinePaginator(Paginator):                  log.trace(f"Setting embed url to '{url}'")              log.debug("There's less than two pages, so we won't paginate - sending single page on its own") -            return await ctx.send(embed=embed) +            await ctx.send(embed=embed) +            return          else:              if footer_text:                  embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") @@ -282,7 +284,7 @@ class ImagePaginator(Paginator):          self.images = []          self._pages = [] -    def add_line(self, line: str = '', *, empty: bool = False) -> None: +    def add_line(self, line: str = "", *, empty: bool = False) -> None:          """          Adds a line to each page, usually just 1 line in this context. @@ -302,7 +304,7 @@ class ImagePaginator(Paginator):      @classmethod      async def paginate(cls, pages: List[Tuple[str, str]], ctx: Context, embed: Embed,                         prefix: str = "", suffix: str = "", timeout: int = 300, -                       exception_on_empty_embed: bool = False): +                       exception_on_empty_embed: bool = False) -> None:          """          Use a paginator and set of reactions to provide pagination over a set of title/image pairs. @@ -352,7 +354,8 @@ class ImagePaginator(Paginator):              embed.set_image(url=image)          if len(paginator.pages) <= 1: -            return await ctx.send(embed=embed) +            await ctx.send(embed=embed) +            return          embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")          message = await ctx.send(embed=embed) @@ -417,9 +420,8 @@ class ImagePaginator(Paginator):              await message.edit(embed=embed)              embed.description = paginator.pages[current_page] -            image = paginator.images[current_page] -            if image: -                embed.set_image(url=image) +            image = paginator.images[current_page] or EmptyEmbed +            embed.set_image(url=image)              embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}")              log.debug(f"Got {reaction_type} page reaction - changing to page {current_page + 1}/{len(paginator.pages)}") diff --git a/bot/utils/time.py b/bot/utils/time.py index 3c57e126..fbf2fd21 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -3,7 +3,7 @@ import datetime  from dateutil.relativedelta import relativedelta -# All these functions are from https://github.com/python-discord/bot/blob/master/bot/utils/time.py +# All these functions are from https://github.com/python-discord/bot/blob/main/bot/utils/time.py  def _stringify_time_unit(value: int, unit: str) -> str:      """      Returns a string to represent a value and time unit, ensuring that it uses the right plural form of the unit. | 
