diff options
Diffstat (limited to 'pydis_core')
| -rw-r--r-- | pydis_core/utils/interactions.py | 38 | 
1 files changed, 26 insertions, 12 deletions
| diff --git a/pydis_core/utils/interactions.py b/pydis_core/utils/interactions.py index d3432c3a..1710069d 100644 --- a/pydis_core/utils/interactions.py +++ b/pydis_core/utils/interactions.py @@ -1,7 +1,7 @@  from collections.abc import Sequence  from typing import Literal -from discord import ButtonStyle, HTTPException, Interaction, Message, NotFound, ui +from discord import ButtonStyle, HTTPException, Interaction, Member, Message, NotFound, User, ui  from pydis_core.utils.logging import get_logger  from pydis_core.utils.scheduling import create_task @@ -9,6 +9,25 @@ from pydis_core.utils.scheduling import create_task  log = get_logger(__name__) +def user_has_access( +    user: User | Member, +    *, +    allowed_users: Sequence[int] = (), +    allowed_roles: Sequence[int] = (), +) -> bool: +    """ +    Return whether the user is in the allowed_users list, or has a role from allowed_roles. + +    Args: +        user: The user to check +        allowed_users: A sequence of user ids that are allowed access +        allowed_roles: A sequence of role ids that are allowed access +    """ +    if user.id in allowed_users or any(role.id in allowed_roles for role in getattr(user, "roles", [])): +        return True +    return False + +  async def _handle_modify_message(message: Message, action: Literal["edit", "delete"]) -> None:      """Remove the view from, or delete the given message depending on the specified action."""      try: @@ -60,18 +79,13 @@ class ViewWithUserAndRoleCheck(ui.View):          Args:              interaction: The interaction that occurred.          """ -        if interaction.user.id in self.allowed_users: -            log.trace( -                "Allowed interaction by %s (%d) on %d as they are an allowed user.", -                interaction.user, -                interaction.user.id, -                interaction.message.id, -            ) -            return True - -        if any(role.id in self.allowed_roles for role in getattr(interaction.user, "roles", [])): +        if user_has_access( +            interaction.user, +            allowed_users=self.allowed_users, +            allowed_roles=self.allowed_roles, +        ):              log.trace( -                "Allowed interaction by %s (%d)on %d as they have an allowed role.", +                "Allowed interaction by %s (%d) on %d as they are an allowed user or have an allowed role.",                  interaction.user,                  interaction.user.id,                  interaction.message.id, | 
