aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bot/pagination.py36
-rw-r--r--bot/utils/messages.py66
-rw-r--r--bot/utils/scheduling.py10
3 files changed, 68 insertions, 44 deletions
diff --git a/bot/pagination.py b/bot/pagination.py
index 3b16cc9ff..c5c84afd9 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -2,14 +2,14 @@ import asyncio
import logging
import typing as t
from contextlib import suppress
+from functools import partial
import discord
-from discord import Member
from discord.abc import User
from discord.ext.commands import Context, Paginator
from bot import constants
-from bot.constants import MODERATION_ROLES
+from bot.utils import messages
FIRST_EMOJI = "\u23EE" # [:track_previous:]
LEFT_EMOJI = "\u2B05" # [:arrow_left:]
@@ -220,29 +220,6 @@ class LinePaginator(Paginator):
>>> embed.set_author(name="Some Operation", url=url, icon_url=icon)
>>> await LinePaginator.paginate([line for line in lines], ctx, embed)
"""
- def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool:
- """Make sure that this reaction is what we want to operate on."""
- no_restrictions = (
- # The reaction was by a whitelisted user
- user_.id == restrict_to_user.id
- # The reaction was by a moderator
- or isinstance(user_, Member) and any(role.id in MODERATION_ROLES for role in user_.roles)
- )
-
- return (
- # Conditions for a successful pagination:
- all((
- # Reaction is on this message
- reaction_.message.id == message.id,
- # Reaction is one of the pagination emotes
- str(reaction_.emoji) in PAGINATION_EMOJI,
- # Reaction was not made by the Bot
- user_.id != ctx.bot.user.id,
- # There were no restrictions
- no_restrictions
- ))
- )
-
paginator = cls(prefix=prefix, suffix=suffix, max_size=max_size, max_lines=max_lines,
scale_to_size=scale_to_size)
current_page = 0
@@ -303,9 +280,16 @@ class LinePaginator(Paginator):
log.trace(f"Adding reaction: {repr(emoji)}")
await message.add_reaction(emoji)
+ check = partial(
+ messages.reaction_check,
+ message_id=message.id,
+ allowed_emoji=PAGINATION_EMOJI,
+ allowed_users=(restrict_to_user.id,),
+ )
+
while True:
try:
- reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=event_check)
+ reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)
log.trace(f"Got reaction: {reaction}")
except asyncio.TimeoutError:
log.debug("Timed out waiting for a reaction")
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index 0bcaed43d..2beead6af 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -3,6 +3,7 @@ import contextlib
import logging
import random
import re
+from functools import partial
from io import BytesIO
from typing import List, Optional, Sequence, Union
@@ -12,24 +13,66 @@ from discord.ext.commands import Context
import bot
from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES
+from bot.utils import scheduling
log = logging.getLogger(__name__)
+def reaction_check(
+ reaction: discord.Reaction,
+ user: discord.abc.User,
+ *,
+ message_id: int,
+ allowed_emoji: Sequence[str],
+ allowed_users: Sequence[int],
+ allow_mods: bool = True,
+) -> bool:
+ """
+ Check if a reaction's emoji and author are allowed and the message is `message_id`.
+
+ If the user is not allowed, remove the reaction. Ignore reactions made by the bot.
+ If `allow_mods` is True, allow users with moderator roles even if they're not in `allowed_users`.
+ """
+ right_reaction = (
+ user != bot.instance.user
+ and reaction.message.id == message_id
+ and str(reaction.emoji) in allowed_emoji
+ )
+ if not right_reaction:
+ return False
+
+ is_moderator = (
+ allow_mods
+ and any(role.id in MODERATION_ROLES for role in getattr(user, "roles", []))
+ )
+
+ if user.id in allowed_users or is_moderator:
+ log.trace(f"Allowed reaction {reaction} by {user} on {reaction.message.id}.")
+ return True
+ else:
+ log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.")
+ scheduling.create_task(
+ reaction.message.remove_reaction(reaction.emoji, user),
+ HTTPException, # Suppress the HTTPException if adding the reaction fails
+ name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}"
+ )
+ return False
+
+
async def wait_for_deletion(
message: discord.Message,
- user_ids: Sequence[discord.abc.Snowflake],
+ user_ids: Sequence[int],
deletion_emojis: Sequence[str] = (Emojis.trashcan,),
timeout: float = 60 * 5,
attach_emojis: bool = True,
- allow_moderation_roles: bool = True
+ allow_mods: bool = True
) -> None:
"""
Wait for up to `timeout` seconds for a reaction by any of the specified `user_ids` to delete the message.
An `attach_emojis` bool may be specified to determine whether to attach the given
`deletion_emojis` to the message in the given `context`.
- An `allow_moderation_roles` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete
+ An `allow_mods` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete
the message.
"""
if message.guild is None:
@@ -43,16 +86,13 @@ async def wait_for_deletion(
log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.")
return
- def check(reaction: discord.Reaction, user: discord.Member) -> bool:
- """Check that the deletion emoji is reacted by the appropriate user."""
- return (
- reaction.message.id == message.id
- and str(reaction.emoji) in deletion_emojis
- and (
- user.id in user_ids
- or allow_moderation_roles and any(role.id in MODERATION_ROLES for role in user.roles)
- )
- )
+ check = partial(
+ reaction_check,
+ message_id=message.id,
+ allowed_emoji=deletion_emojis,
+ allowed_users=user_ids,
+ allow_mods=allow_mods,
+ )
with contextlib.suppress(asyncio.TimeoutError):
await bot.instance.wait_for('reaction_add', check=check, timeout=timeout)
diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py
index 6843bae88..2dc485f24 100644
--- a/bot/utils/scheduling.py
+++ b/bot/utils/scheduling.py
@@ -161,18 +161,18 @@ class Scheduler:
self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
-def create_task(*args, **kwargs) -> asyncio.Task:
+def create_task(coro: t.Awaitable, *suppressed_exceptions: t.Type[Exception], **kwargs) -> asyncio.Task:
"""Wrapper for `asyncio.create_task` which logs exceptions raised in the task."""
- task = asyncio.create_task(*args, **kwargs)
- task.add_done_callback(_log_task_exception)
+ task = asyncio.create_task(coro, **kwargs)
+ task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
return task
-def _log_task_exception(task: asyncio.Task) -> None:
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None:
"""Retrieve and log the exception raised in `task` if one exists."""
with contextlib.suppress(asyncio.CancelledError):
exception = task.exception()
# Log the exception if one exists.
- if exception:
+ if exception and not isinstance(exception, suppressed_exceptions):
log = logging.getLogger(__name__)
log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)