diff options
| -rw-r--r-- | bot/pagination.py | 36 | ||||
| -rw-r--r-- | bot/utils/messages.py | 66 | ||||
| -rw-r--r-- | bot/utils/scheduling.py | 10 |
3 files changed, 68 insertions, 44 deletions
diff --git a/bot/pagination.py b/bot/pagination.py index 3b16cc9ff..c5c84afd9 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -2,14 +2,14 @@ import asyncio import logging import typing as t from contextlib import suppress +from functools import partial import discord -from discord import Member from discord.abc import User from discord.ext.commands import Context, Paginator from bot import constants -from bot.constants import MODERATION_ROLES +from bot.utils import messages FIRST_EMOJI = "\u23EE" # [:track_previous:] LEFT_EMOJI = "\u2B05" # [:arrow_left:] @@ -220,29 +220,6 @@ class LinePaginator(Paginator): >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) >>> await LinePaginator.paginate([line for line in lines], ctx, embed) """ - def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool: - """Make sure that this reaction is what we want to operate on.""" - no_restrictions = ( - # The reaction was by a whitelisted user - user_.id == restrict_to_user.id - # The reaction was by a moderator - or isinstance(user_, Member) and any(role.id in MODERATION_ROLES for role in user_.roles) - ) - - return ( - # Conditions for a successful pagination: - all(( - # Reaction is on this message - reaction_.message.id == message.id, - # Reaction is one of the pagination emotes - str(reaction_.emoji) in PAGINATION_EMOJI, - # Reaction was not made by the Bot - user_.id != ctx.bot.user.id, - # There were no restrictions - no_restrictions - )) - ) - paginator = cls(prefix=prefix, suffix=suffix, max_size=max_size, max_lines=max_lines, scale_to_size=scale_to_size) current_page = 0 @@ -303,9 +280,16 @@ class LinePaginator(Paginator): log.trace(f"Adding reaction: {repr(emoji)}") await message.add_reaction(emoji) + check = partial( + messages.reaction_check, + message_id=message.id, + allowed_emoji=PAGINATION_EMOJI, + allowed_users=(restrict_to_user.id,), + ) + while True: try: - reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=event_check) + reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check) log.trace(f"Got reaction: {reaction}") except asyncio.TimeoutError: log.debug("Timed out waiting for a reaction") diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 0bcaed43d..2beead6af 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -3,6 +3,7 @@ import contextlib import logging import random import re +from functools import partial from io import BytesIO from typing import List, Optional, Sequence, Union @@ -12,24 +13,66 @@ from discord.ext.commands import Context import bot from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES +from bot.utils import scheduling log = logging.getLogger(__name__) +def reaction_check( + reaction: discord.Reaction, + user: discord.abc.User, + *, + message_id: int, + allowed_emoji: Sequence[str], + allowed_users: Sequence[int], + allow_mods: bool = True, +) -> bool: + """ + Check if a reaction's emoji and author are allowed and the message is `message_id`. + + If the user is not allowed, remove the reaction. Ignore reactions made by the bot. + If `allow_mods` is True, allow users with moderator roles even if they're not in `allowed_users`. + """ + right_reaction = ( + user != bot.instance.user + and reaction.message.id == message_id + and str(reaction.emoji) in allowed_emoji + ) + if not right_reaction: + return False + + is_moderator = ( + allow_mods + and any(role.id in MODERATION_ROLES for role in getattr(user, "roles", [])) + ) + + if user.id in allowed_users or is_moderator: + log.trace(f"Allowed reaction {reaction} by {user} on {reaction.message.id}.") + return True + else: + log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.") + scheduling.create_task( + reaction.message.remove_reaction(reaction.emoji, user), + HTTPException, # Suppress the HTTPException if adding the reaction fails + name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}" + ) + return False + + async def wait_for_deletion( message: discord.Message, - user_ids: Sequence[discord.abc.Snowflake], + user_ids: Sequence[int], deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, attach_emojis: bool = True, - allow_moderation_roles: bool = True + allow_mods: bool = True ) -> None: """ Wait for up to `timeout` seconds for a reaction by any of the specified `user_ids` to delete the message. An `attach_emojis` bool may be specified to determine whether to attach the given `deletion_emojis` to the message in the given `context`. - An `allow_moderation_roles` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete + An `allow_mods` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete the message. """ if message.guild is None: @@ -43,16 +86,13 @@ async def wait_for_deletion( log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.") return - def check(reaction: discord.Reaction, user: discord.Member) -> bool: - """Check that the deletion emoji is reacted by the appropriate user.""" - return ( - reaction.message.id == message.id - and str(reaction.emoji) in deletion_emojis - and ( - user.id in user_ids - or allow_moderation_roles and any(role.id in MODERATION_ROLES for role in user.roles) - ) - ) + check = partial( + reaction_check, + message_id=message.id, + allowed_emoji=deletion_emojis, + allowed_users=user_ids, + allow_mods=allow_mods, + ) with contextlib.suppress(asyncio.TimeoutError): await bot.instance.wait_for('reaction_add', check=check, timeout=timeout) diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index 6843bae88..2dc485f24 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -161,18 +161,18 @@ class Scheduler: self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) -def create_task(*args, **kwargs) -> asyncio.Task: +def create_task(coro: t.Awaitable, *suppressed_exceptions: t.Type[Exception], **kwargs) -> asyncio.Task: """Wrapper for `asyncio.create_task` which logs exceptions raised in the task.""" - task = asyncio.create_task(*args, **kwargs) - task.add_done_callback(_log_task_exception) + task = asyncio.create_task(coro, **kwargs) + task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions)) return task -def _log_task_exception(task: asyncio.Task) -> None: +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None: """Retrieve and log the exception raised in `task` if one exists.""" with contextlib.suppress(asyncio.CancelledError): exception = task.exception() # Log the exception if one exists. - if exception: + if exception and not isinstance(exception, suppressed_exceptions): log = logging.getLogger(__name__) log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) |