diff options
author | 2022-02-14 23:41:14 +0000 | |
---|---|---|
committer | 2022-02-14 23:41:14 +0000 | |
commit | 657aaf5f613b0bd219d67c71f7dc37ede0bb6aab (patch) | |
tree | 548e48ae21a98d5e6e16e8619d1af4ee75fe8b7a /bot/utils | |
parent | fix: Add newlines in codeblock formatting (diff) | |
parent | Merge pull request #1026 from Shom770/fix-trivia-night (diff) |
Merge branch 'main' into merge-github-issues
Diffstat (limited to 'bot/utils')
-rw-r--r-- | bot/utils/decorators.py | 27 | ||||
-rw-r--r-- | bot/utils/exceptions.py | 7 | ||||
-rw-r--r-- | bot/utils/members.py | 47 |
3 files changed, 71 insertions, 10 deletions
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 132aaa87..7a3b14ad 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -196,15 +196,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo If `whitelist_override` is present, it is added to the global whitelist. """ def predicate(ctx: Context) -> bool: - # Skip DM invocations - if not ctx.guild: - log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.") - return True - kwargs = default_kwargs.copy() + allow_dms = False # Update kwargs based on override if hasattr(ctx.command.callback, "override"): + # Handle DM invocations + allow_dms = ctx.command.callback.override_dm + # Remove default kwargs if reset is True if ctx.command.callback.override_reset: kwargs = {} @@ -234,8 +233,12 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo f"invoked by {ctx.author}." ) - log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") - result = in_whitelist_check(ctx, fail_silently=True, **kwargs) + if ctx.guild is None: + log.debug(f"{ctx.author} tried using the '{ctx.command.name}' command from a DM.") + result = allow_dms + else: + log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") + result = in_whitelist_check(ctx, fail_silently=True, **kwargs) # Return if check passed if result: @@ -260,8 +263,8 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo default_whitelist_channels.discard(Channels.community_bot_commands) channels.difference_update(default_whitelist_channels) - # Add all whitelisted category channels - if categories: + # Add all whitelisted category channels, but skip if we're in DMs + if categories and ctx.guild is not None: for category_id in categories: category = ctx.guild.get_channel(category_id) if category is None: @@ -280,18 +283,22 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo return predicate -def whitelist_override(bypass_defaults: bool = False, **kwargs: Container[int]) -> Callable: +def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **kwargs: Container[int]) -> Callable: """ Override global whitelist context, with the kwargs specified. All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`. Set `bypass_defaults` to True if you want to completely bypass global checks. + Set `allow_dm` to True if you want to allow the command to be invoked from within direct messages. + Note that you have to be careful with any references to the guild. + This decorator has to go before (below) below the `command` decorator. """ def inner(func: Callable) -> Callable: func.override = kwargs func.override_reset = bypass_defaults + func.override_dm = allow_dm return func return inner diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index bf0e5813..3cd96325 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -15,3 +15,10 @@ class APIError(Exception): self.api = api self.status_code = status_code self.error_msg = error_msg + + +class MovedCommandError(Exception): + """Raised when a command has moved locations.""" + + def __init__(self, new_command_name: str): + self.new_command_name = new_command_name diff --git a/bot/utils/members.py b/bot/utils/members.py new file mode 100644 index 00000000..de5850ca --- /dev/null +++ b/bot/utils/members.py @@ -0,0 +1,47 @@ +import logging +import typing as t + +import discord + +log = logging.getLogger(__name__) + + +async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optional[discord.Member]: + """ + Attempt to get a member from cache; on failure fetch from the API. + + Return `None` to indicate the member could not be found. + """ + if member := guild.get_member(member_id): + log.trace("%s retrieved from cache.", member) + else: + try: + member = await guild.fetch_member(member_id) + except discord.errors.NotFound: + log.trace("Failed to fetch %d from API.", member_id) + return None + log.trace("%s fetched from API.", member) + return member + + +async def handle_role_change( + member: discord.Member, + coro: t.Callable[..., t.Coroutine], + role: discord.Role +) -> None: + """ + Change `member`'s cooldown role via awaiting `coro` and handle errors. + + `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`. + """ + try: + await coro(role) + except discord.NotFound: + log.debug(f"Failed to change role for {member} ({member.id}): member not found") + except discord.Forbidden: + log.error( + f"Forbidden to change role for {member} ({member.id}); " + f"possibly due to role hierarchy" + ) + except discord.HTTPException as e: + log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}") |