diff options
Diffstat (limited to 'bot/utils')
| -rw-r--r-- | bot/utils/checks.py | 11 | ||||
| -rw-r--r-- | bot/utils/decorators.py | 31 | ||||
| -rw-r--r-- | bot/utils/exceptions.py | 7 | ||||
| -rw-r--r-- | bot/utils/halloween/spookifications.py | 3 | ||||
| -rw-r--r-- | bot/utils/members.py | 47 | ||||
| -rw-r--r-- | bot/utils/pagination.py | 10 | 
6 files changed, 76 insertions, 33 deletions
diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 612d1ed6..5433f436 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -4,14 +4,7 @@ from collections.abc import Container, Iterable  from typing import Callable, Optional  from discord.ext.commands import ( -    BucketType, -    CheckFailure, -    Cog, -    Command, -    CommandOnCooldown, -    Context, -    Cooldown, -    CooldownMapping, +    BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping  )  from bot import constants @@ -40,7 +33,7 @@ def in_whitelist_check(      channels: Container[int] = (),      categories: Container[int] = (),      roles: Container[int] = (), -    redirect: Optional[int] = constants.Channels.community_bot_commands, +    redirect: Optional[int] = constants.Channels.sir_lancebot_playground,      fail_silently: bool = False,  ) -> bool:      """ diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 132aaa87..8954e016 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -196,15 +196,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo      If `whitelist_override` is present, it is added to the global whitelist.      """      def predicate(ctx: Context) -> bool: -        # Skip DM invocations -        if not ctx.guild: -            log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") -            return True -          kwargs = default_kwargs.copy() +        allow_dms = False          # Update kwargs based on override          if hasattr(ctx.command.callback, "override"): +            # Handle DM invocations +            allow_dms = ctx.command.callback.override_dm +              # Remove default kwargs if reset is True              if ctx.command.callback.override_reset:                  kwargs = {} @@ -234,8 +233,12 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo                  f"invoked by {ctx.author}."              ) -        log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") -        result = in_whitelist_check(ctx, fail_silently=True, **kwargs) +        if ctx.guild is None: +            log.debug(f"{ctx.author} tried using the '{ctx.command.name}' command from a DM.") +            result = allow_dms +        else: +            log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") +            result = in_whitelist_check(ctx, fail_silently=True, **kwargs)          # Return if check passed          if result: @@ -254,14 +257,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo          channels = set(kwargs.get("channels") or {})          categories = kwargs.get("categories") -        # Only output override channels + community_bot_commands +        # Only output override channels + sir_lancebot_playground          if channels:              default_whitelist_channels = set(WHITELISTED_CHANNELS) -            default_whitelist_channels.discard(Channels.community_bot_commands) +            default_whitelist_channels.discard(Channels.sir_lancebot_playground)              channels.difference_update(default_whitelist_channels) -        # Add all whitelisted category channels -        if categories: +        # Add all whitelisted category channels, but skip if we're in DMs +        if categories and ctx.guild is not None:              for category_id in categories:                  category = ctx.guild.get_channel(category_id)                  if category is None: @@ -280,18 +283,22 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo      return predicate -def whitelist_override(bypass_defaults: bool = False, **kwargs: Container[int]) -> Callable: +def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **kwargs: Container[int]) -> Callable:      """      Override global whitelist context, with the kwargs specified.      All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`.      Set `bypass_defaults` to True if you want to completely bypass global checks. +    Set `allow_dm` to True if you want to allow the command to be invoked from within direct messages. +    Note that you have to be careful with any references to the guild. +      This decorator has to go before (below) below the `command` decorator.      """      def inner(func: Callable) -> Callable:          func.override = kwargs          func.override_reset = bypass_defaults +        func.override_dm = allow_dm          return func      return inner diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index bf0e5813..3cd96325 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -15,3 +15,10 @@ class APIError(Exception):          self.api = api          self.status_code = status_code          self.error_msg = error_msg + + +class MovedCommandError(Exception): +    """Raised when a command has moved locations.""" + +    def __init__(self, new_command_name: str): +        self.new_command_name = new_command_name diff --git a/bot/utils/halloween/spookifications.py b/bot/utils/halloween/spookifications.py index 93c5ddb9..c45ef8dc 100644 --- a/bot/utils/halloween/spookifications.py +++ b/bot/utils/halloween/spookifications.py @@ -1,8 +1,7 @@  import logging  from random import choice, randint -from PIL import Image -from PIL import ImageOps +from PIL import Image, ImageOps  log = logging.getLogger() diff --git a/bot/utils/members.py b/bot/utils/members.py new file mode 100644 index 00000000..de5850ca --- /dev/null +++ b/bot/utils/members.py @@ -0,0 +1,47 @@ +import logging +import typing as t + +import discord + +log = logging.getLogger(__name__) + + +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optional[discord.Member]: +    """ +    Attempt to get a member from cache; on failure fetch from the API. + +    Return `None` to indicate the member could not be found. +    """ +    if member := guild.get_member(member_id): +        log.trace("%s retrieved from cache.", member) +    else: +        try: +            member = await guild.fetch_member(member_id) +        except discord.errors.NotFound: +            log.trace("Failed to fetch %d from API.", member_id) +            return None +        log.trace("%s fetched from API.", member) +    return member + + +async def handle_role_change( +    member: discord.Member, +    coro: t.Callable[..., t.Coroutine], +    role: discord.Role +) -> None: +    """ +    Change `member`'s cooldown role via awaiting `coro` and handle errors. + +    `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. +    """ +    try: +        await coro(role) +    except discord.NotFound: +        log.debug(f"Failed to change role for {member} ({member.id}): member not found") +    except discord.Forbidden: +        log.error( +            f"Forbidden to change role for {member} ({member.id}); " +            f"possibly due to role hierarchy" +        ) +    except discord.HTTPException as e: +        log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py index 013ef9e7..188b279f 100644 --- a/bot/utils/pagination.py +++ b/bot/utils/pagination.py @@ -211,8 +211,6 @@ class LinePaginator(Paginator):                  log.debug(f"Got first page reaction - changing to page 1/{len(paginator.pages)}") -                embed.description = "" -                await message.edit(embed=embed)                  embed.description = paginator.pages[current_page]                  if footer_text:                      embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") @@ -226,8 +224,6 @@ class LinePaginator(Paginator):                  log.debug(f"Got last page reaction - changing to page {current_page + 1}/{len(paginator.pages)}") -                embed.description = "" -                await message.edit(embed=embed)                  embed.description = paginator.pages[current_page]                  if footer_text:                      embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") @@ -245,8 +241,6 @@ class LinePaginator(Paginator):                  current_page -= 1                  log.debug(f"Got previous page reaction - changing to page {current_page + 1}/{len(paginator.pages)}") -                embed.description = "" -                await message.edit(embed=embed)                  embed.description = paginator.pages[current_page]                  if footer_text: @@ -266,8 +260,6 @@ class LinePaginator(Paginator):                  current_page += 1                  log.debug(f"Got next page reaction - changing to page {current_page + 1}/{len(paginator.pages)}") -                embed.description = "" -                await message.edit(embed=embed)                  embed.description = paginator.pages[current_page]                  if footer_text: @@ -428,8 +420,6 @@ class ImagePaginator(Paginator):                  reaction_type = "next"              # Magic happens here, after page and reaction_type is set -            embed.description = "" -            await message.edit(embed=embed)              embed.description = paginator.pages[current_page]              image = paginator.images[current_page] or EmptyEmbed  |