aboutsummaryrefslogtreecommitdiffstats
path: root/bot/utils
diff options
context:
space:
mode:
Diffstat (limited to 'bot/utils')
-rw-r--r--bot/utils/checks.py11
-rw-r--r--bot/utils/decorators.py31
-rw-r--r--bot/utils/exceptions.py7
-rw-r--r--bot/utils/halloween/spookifications.py3
-rw-r--r--bot/utils/members.py47
-rw-r--r--bot/utils/pagination.py10
6 files changed, 76 insertions, 33 deletions
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index 612d1ed6..5433f436 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -4,14 +4,7 @@ from collections.abc import Container, Iterable
from typing import Callable, Optional
from discord.ext.commands import (
- BucketType,
- CheckFailure,
- Cog,
- Command,
- CommandOnCooldown,
- Context,
- Cooldown,
- CooldownMapping,
+ BucketType, CheckFailure, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping
)
from bot import constants
@@ -40,7 +33,7 @@ def in_whitelist_check(
channels: Container[int] = (),
categories: Container[int] = (),
roles: Container[int] = (),
- redirect: Optional[int] = constants.Channels.community_bot_commands,
+ redirect: Optional[int] = constants.Channels.sir_lancebot_playground,
fail_silently: bool = False,
) -> bool:
"""
diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py
index 132aaa87..8954e016 100644
--- a/bot/utils/decorators.py
+++ b/bot/utils/decorators.py
@@ -196,15 +196,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
If `whitelist_override` is present, it is added to the global whitelist.
"""
def predicate(ctx: Context) -> bool:
- # Skip DM invocations
- if not ctx.guild:
- log.debug(f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM.")
- return True
-
kwargs = default_kwargs.copy()
+ allow_dms = False
# Update kwargs based on override
if hasattr(ctx.command.callback, "override"):
+ # Handle DM invocations
+ allow_dms = ctx.command.callback.override_dm
+
# Remove default kwargs if reset is True
if ctx.command.callback.override_reset:
kwargs = {}
@@ -234,8 +233,12 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
f"invoked by {ctx.author}."
)
- log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.")
- result = in_whitelist_check(ctx, fail_silently=True, **kwargs)
+ if ctx.guild is None:
+ log.debug(f"{ctx.author} tried using the '{ctx.command.name}' command from a DM.")
+ result = allow_dms
+ else:
+ log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.")
+ result = in_whitelist_check(ctx, fail_silently=True, **kwargs)
# Return if check passed
if result:
@@ -254,14 +257,14 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
channels = set(kwargs.get("channels") or {})
categories = kwargs.get("categories")
- # Only output override channels + community_bot_commands
+ # Only output override channels + sir_lancebot_playground
if channels:
default_whitelist_channels = set(WHITELISTED_CHANNELS)
- default_whitelist_channels.discard(Channels.community_bot_commands)
+ default_whitelist_channels.discard(Channels.sir_lancebot_playground)
channels.difference_update(default_whitelist_channels)
- # Add all whitelisted category channels
- if categories:
+ # Add all whitelisted category channels, but skip if we're in DMs
+ if categories and ctx.guild is not None:
for category_id in categories:
category = ctx.guild.get_channel(category_id)
if category is None:
@@ -280,18 +283,22 @@ def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], boo
return predicate
-def whitelist_override(bypass_defaults: bool = False, **kwargs: Container[int]) -> Callable:
+def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **kwargs: Container[int]) -> Callable:
"""
Override global whitelist context, with the kwargs specified.
All arguments from `in_whitelist_check` are supported, with the exception of `fail_silently`.
Set `bypass_defaults` to True if you want to completely bypass global checks.
+ Set `allow_dm` to True if you want to allow the command to be invoked from within direct messages.
+ Note that you have to be careful with any references to the guild.
+
This decorator has to go before (below) below the `command` decorator.
"""
def inner(func: Callable) -> Callable:
func.override = kwargs
func.override_reset = bypass_defaults
+ func.override_dm = allow_dm
return func
return inner
diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py
index bf0e5813..3cd96325 100644
--- a/bot/utils/exceptions.py
+++ b/bot/utils/exceptions.py
@@ -15,3 +15,10 @@ class APIError(Exception):
self.api = api
self.status_code = status_code
self.error_msg = error_msg
+
+
+class MovedCommandError(Exception):
+ """Raised when a command has moved locations."""
+
+ def __init__(self, new_command_name: str):
+ self.new_command_name = new_command_name
diff --git a/bot/utils/halloween/spookifications.py b/bot/utils/halloween/spookifications.py
index 93c5ddb9..c45ef8dc 100644
--- a/bot/utils/halloween/spookifications.py
+++ b/bot/utils/halloween/spookifications.py
@@ -1,8 +1,7 @@
import logging
from random import choice, randint
-from PIL import Image
-from PIL import ImageOps
+from PIL import Image, ImageOps
log = logging.getLogger()
diff --git a/bot/utils/members.py b/bot/utils/members.py
new file mode 100644
index 00000000..de5850ca
--- /dev/null
+++ b/bot/utils/members.py
@@ -0,0 +1,47 @@
+import logging
+import typing as t
+
+import discord
+
+log = logging.getLogger(__name__)
+
+
+async def get_or_fetch_member(guild: discord.Guild, member_id: int) -> t.Optional[discord.Member]:
+ """
+ Attempt to get a member from cache; on failure fetch from the API.
+
+ Return `None` to indicate the member could not be found.
+ """
+ if member := guild.get_member(member_id):
+ log.trace("%s retrieved from cache.", member)
+ else:
+ try:
+ member = await guild.fetch_member(member_id)
+ except discord.errors.NotFound:
+ log.trace("Failed to fetch %d from API.", member_id)
+ return None
+ log.trace("%s fetched from API.", member)
+ return member
+
+
+async def handle_role_change(
+ member: discord.Member,
+ coro: t.Callable[..., t.Coroutine],
+ role: discord.Role
+) -> None:
+ """
+ Change `member`'s cooldown role via awaiting `coro` and handle errors.
+
+ `coro` is intended to be `discord.Member.add_roles` or `discord.Member.remove_roles`.
+ """
+ try:
+ await coro(role)
+ except discord.NotFound:
+ log.debug(f"Failed to change role for {member} ({member.id}): member not found")
+ except discord.Forbidden:
+ log.error(
+ f"Forbidden to change role for {member} ({member.id}); "
+ f"possibly due to role hierarchy"
+ )
+ except discord.HTTPException as e:
+ log.error(f"Failed to change role for {member} ({member.id}): {e.status} {e.code}")
diff --git a/bot/utils/pagination.py b/bot/utils/pagination.py
index 013ef9e7..188b279f 100644
--- a/bot/utils/pagination.py
+++ b/bot/utils/pagination.py
@@ -211,8 +211,6 @@ class LinePaginator(Paginator):
log.debug(f"Got first page reaction - changing to page 1/{len(paginator.pages)}")
- embed.description = ""
- await message.edit(embed=embed)
embed.description = paginator.pages[current_page]
if footer_text:
embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
@@ -226,8 +224,6 @@ class LinePaginator(Paginator):
log.debug(f"Got last page reaction - changing to page {current_page + 1}/{len(paginator.pages)}")
- embed.description = ""
- await message.edit(embed=embed)
embed.description = paginator.pages[current_page]
if footer_text:
embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})")
@@ -245,8 +241,6 @@ class LinePaginator(Paginator):
current_page -= 1
log.debug(f"Got previous page reaction - changing to page {current_page + 1}/{len(paginator.pages)}")
- embed.description = ""
- await message.edit(embed=embed)
embed.description = paginator.pages[current_page]
if footer_text:
@@ -266,8 +260,6 @@ class LinePaginator(Paginator):
current_page += 1
log.debug(f"Got next page reaction - changing to page {current_page + 1}/{len(paginator.pages)}")
- embed.description = ""
- await message.edit(embed=embed)
embed.description = paginator.pages[current_page]
if footer_text:
@@ -428,8 +420,6 @@ class ImagePaginator(Paginator):
reaction_type = "next"
# Magic happens here, after page and reaction_type is set
- embed.description = ""
- await message.edit(embed=embed)
embed.description = paginator.pages[current_page]
image = paginator.images[current_page] or EmptyEmbed