diff options
author | 2021-10-08 12:07:33 -0700 | |
---|---|---|
committer | 2021-10-08 12:07:33 -0700 | |
commit | 0cc1e8a527216797492be2a5d9875172a4312119 (patch) | |
tree | f8253973da15069bf258c5e7ec3ed2cd2e42a630 | |
parent | Merge pull request #893 from python-discord/bug/candy-collection/fix-typeerro... (diff) | |
parent | Merge branch 'main' into topic-improvements (diff) |
Merge pull request #895 from python-discord/topic-improvements
`.topic` command improvements
-rw-r--r-- | bot/exts/utilities/conversationstarters.py | 52 |
1 files changed, 33 insertions, 19 deletions
diff --git a/bot/exts/utilities/conversationstarters.py b/bot/exts/utilities/conversationstarters.py index 5d62fa83..dcbfe4d5 100644 --- a/bot/exts/utilities/conversationstarters.py +++ b/bot/exts/utilities/conversationstarters.py @@ -2,6 +2,7 @@ import asyncio from contextlib import suppress from functools import partial from pathlib import Path +from typing import Union import discord import yaml @@ -64,35 +65,48 @@ class ConvoStarters(commands.Cog): embed.title = f"**{next(channel_topics)}**" return embed - def _predicate(self, message: discord.Message, reaction: discord.Reaction, user: discord.User) -> bool: - right_reaction = ( - user != self.bot.user - and reaction.message.id == message.id - and str(reaction.emoji) == "🔄" - ) - if not right_reaction: - return False - - is_moderator = any(role.id in MODERATION_ROLES for role in getattr(user, "roles", [])) - if is_moderator or user.id == message.author.id: - return True - - return False - - async def _listen_for_refresh(self, message: discord.Message) -> None: + @staticmethod + def _predicate( + command_invoker: Union[discord.User, discord.Member], + message: discord.Message, + reaction: discord.Reaction, + user: discord.User + ) -> bool: + user_is_moderator = any(role.id in MODERATION_ROLES for role in getattr(user, "roles", [])) + user_is_invoker = user.id == command_invoker.id + + is_right_reaction = all(( + reaction.message.id == message.id, + str(reaction.emoji) == "🔄", + user_is_moderator or user_is_invoker + )) + return is_right_reaction + + async def _listen_for_refresh( + self, + command_invoker: Union[discord.User, discord.Member], + message: discord.Message + ) -> None: await message.add_reaction("🔄") while True: try: reaction, user = await self.bot.wait_for( "reaction_add", - check=partial(self._predicate, message), + check=partial(self._predicate, command_invoker, message), timeout=60.0 ) except asyncio.TimeoutError: with suppress(discord.NotFound): await message.clear_reaction("🔄") - else: + break + + try: await message.edit(embed=self._build_topic_embed(message.channel.id)) + except discord.NotFound: + break + + with suppress(discord.NotFound): + await message.remove_reaction(reaction, user) @commands.command() @commands.cooldown(1, 60*2, commands.BucketType.channel) @@ -104,7 +118,7 @@ class ConvoStarters(commands.Cog): Allows the refresh of a topic by pressing an emoji. """ message = await ctx.send(embed=self._build_topic_embed(ctx.channel.id)) - self.bot.loop.create_task(self._listen_for_refresh(message)) + self.bot.loop.create_task(self._listen_for_refresh(ctx.author, message)) def setup(bot: Bot) -> None: |