aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_core/utils
diff options
context:
space:
mode:
authorGravatar shtlrs <[email protected]>2023-12-07 17:52:08 +0100
committerGravatar Chris Lovering <[email protected]>2024-01-30 19:07:36 +0000
commitb457c9b9cfad0d6205b2b8e53be2e98d0db59677 (patch)
tree094142591bc3fccf059b27f2af5942d6198c281f /pydis_core/utils
parentport pagination tests (diff)
port all checks from sir-lancebot and bot
Diffstat (limited to 'pydis_core/utils')
-rw-r--r--pydis_core/utils/checks.py187
1 files changed, 187 insertions, 0 deletions
diff --git a/pydis_core/utils/checks.py b/pydis_core/utils/checks.py
new file mode 100644
index 00000000..809e98de
--- /dev/null
+++ b/pydis_core/utils/checks.py
@@ -0,0 +1,187 @@
+import datetime
+from collections.abc import Callable, Container, Iterable
+
+from discord.ext.commands import (
+ BucketType,
+ CheckFailure,
+ Cog,
+ Command,
+ CommandOnCooldown,
+ Context,
+ Cooldown,
+ CooldownMapping,
+)
+
+from pydis_core.utils.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class ContextCheckFailure(CheckFailure):
+ """Raised when a context-specific check fails."""
+
+ def __init__(self, redirect_channel: int | None) -> None:
+ self.redirect_channel = redirect_channel
+
+ if redirect_channel:
+ redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
+ else:
+ redirect_message = ""
+
+ error_message = f"You are not allowed to use that command{redirect_message}."
+
+ super().__init__(error_message)
+
+
+class InWhitelistCheckFailure(ContextCheckFailure):
+ """Raised when the `in_whitelist` check fails."""
+
+
+def in_whitelist_check(
+ ctx: Context,
+ redirect: int,
+ channels: Container[int] = (),
+ categories: Container[int] = (),
+ roles: Container[int] = (),
+ fail_silently: bool = False,
+) -> bool:
+ """
+ Check if a command was issued in a whitelisted context.
+
+ The whitelists that can be provided are:
+
+ - `channels`: a container with channel ids for whitelisted channels
+ - `categories`: a container with category ids for whitelisted categories
+ - `roles`: a container with with role ids for whitelisted roles
+
+ If the command was invoked in a context that was not whitelisted, the member is either
+ redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
+ told that they're not allowed to use this particular command (if `None` was passed).
+ """
+ if redirect not in channels:
+ # It does not make sense for the channel whitelist to not contain the redirection
+ # channel (if applicable). That's why we add the redirection channel to the `channels`
+ # container if it's not already in it. As we allow any container type to be passed,
+ # we first create a tuple in order to safely add the redirection channel.
+ #
+ # Note: It's possible for the redirect channel to be in a whitelisted category, but
+ # there's no easy way to check that and as a channel can easily be moved in and out of
+ # categories, it's probably not wise to rely on its category in any case.
+ channels = tuple(channels) + (redirect,)
+
+ if channels and ctx.channel.id in channels:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
+ return True
+
+ # Only check the category id if we have a category whitelist and the channel has a `category_id`
+ if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
+ return True
+
+ # Only check the roles whitelist if we have one and ensure the author's roles attribute returns
+ # an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
+ if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
+ log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
+ return True
+
+ log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")
+
+ # Some commands are secret, and should produce no feedback at all.
+ if not fail_silently:
+ raise InWhitelistCheckFailure(redirect)
+ return False
+
+
+def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
+ bypass_roles: Iterable[int]) -> Callable:
+ """
+ Applies a cooldown to a command, but allows members with certain roles to be ignored.
+
+ NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future.
+ """
+ # Make it a set so lookup is hash based.
+ bypass = set(bypass_roles)
+
+ # This handles the actual cooldown logic.
+ buckets = CooldownMapping(Cooldown(rate, per, type))
+
+ # Will be called after the command has been parse but before it has been invoked, ensures that
+ # the cooldown won't be updated if the user screws up their input to the command.
+ async def predicate(cog: Cog, ctx: Context) -> None:
+ nonlocal bypass, buckets
+
+ if any(role.id in bypass for role in ctx.author.roles):
+ return
+
+ # Cooldown logic, taken from discord.py internals.
+ current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp()
+ bucket = buckets.get_bucket(ctx.message)
+ retry_after = bucket.update_rate_limit(current)
+ if retry_after:
+ raise CommandOnCooldown(bucket, retry_after)
+
+ def wrapper(command: Command) -> Command:
+ # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it
+ # so I just made it raise an error when the decorator is applied before the actual command object exists.
+ #
+ # If the `before_invoke` detail is ever a problem then I can quickly just swap over.
+ if not isinstance(command, Command):
+ raise TypeError(
+ "Decorator `cooldown_with_role_bypass` must be applied after the command decorator. "
+ "This means it has to be above the command decorator in the code."
+ )
+
+ command._before_invoke = predicate
+
+ return command
+
+ return wrapper
+
+
+async def has_any_role_check(ctx: Context, *roles: str | int) -> bool:
+ """
+ Returns True if the context's author has any of the specified roles.
+
+ `roles` are the names or IDs of the roles for which to check.
+ False is always returns if the context is outside a guild.
+ """
+ if not ctx.guild: # Return False in a DM
+ log.trace(
+ f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
+ "This command is restricted by the with_role decorator. Rejecting request."
+ )
+ return False
+
+ for role in ctx.author.roles:
+ if role.id in roles:
+ log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")
+ return True
+
+ log.trace(
+ f"{ctx.author} does not have the required role to use "
+ f"the '{ctx.command.name}' command, so the request is rejected."
+ )
+ return False
+
+
+async def has_no_roles_check(ctx: Context, *roles: str | int) -> bool:
+ """
+ Returns True if the context's author doesn't have any of the specified roles.
+
+ `roles` are the names or IDs of the roles for which to check.
+ False is always returns if the context is outside a guild.
+ """
+ if not ctx.guild: # Return False in a DM
+ log.trace(
+ f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
+ "This command is restricted by the without_role decorator. Rejecting request."
+ )
+ return False
+
+ author_roles = [role.id for role in ctx.author.roles]
+ check = all(role not in author_roles for role in roles)
+ log.trace(
+ f"{ctx.author} tried to call the '{ctx.command.name}' command. "
+ f"The result of the without_role check was {check}."
+ )
+ return check