diff options
45 files changed, 2242 insertions, 946 deletions
diff --git a/.gitignore b/.gitignore index fb3156ab1..2074887ad 100644 --- a/.gitignore +++ b/.gitignore @@ -110,6 +110,7 @@ ENV/  # Logfiles  log.* +*.log.*  # Custom user configuration  config.yml diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY new file mode 100644 index 000000000..eacd9b952 --- /dev/null +++ b/LICENSE-THIRD-PARTY @@ -0,0 +1,88 @@ +--------------------------------------------------------------------------------------------------- +                                       BSD 3-Clause License +Applies to: +    - Copyright (c) 2008-Present, IPython Development Team +      Copyright (c) 2001-2007, Fernando Perez <[email protected]> +      Copyright (c) 2001, Janko Hauser <[email protected]> +      Copyright (c) 2001, Nathaniel Gray <[email protected]> +      All rights reserved. +        - bot/exts/info/codeblock/_parsing.py: _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL +--------------------------------------------------------------------------------------------------- + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this +  list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, +  this list of conditions and the following disclaimer in the documentation +  and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its +  contributors may be used to endorse or promote products derived from +  this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--------------------------------------------------------------------------------------------------- +                           PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +Applies to: +    - Copyright © 2001-2020 Python Software Foundation. All rights reserved. +        - tests/_autospec.py: _decoration_helper +--------------------------------------------------------------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation; +All Rights Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis.  PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED.  BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee.  This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. @@ -14,7 +14,7 @@ beautifulsoup4 = "~=4.9"  colorama = {version = "~=0.4.3",sys_platform = "== 'win32'"}  coloredlogs = "~=14.0"  deepdiff = "~=4.0" -discord.py = "~=1.4.0" +"discord.py" = "~=1.5.0"  feedparser = "~=5.2"  fuzzywuzzy = "~=0.17"  lxml = "~=4.4" diff --git a/Pipfile.lock b/Pipfile.lock index 4c63277de..becd85c55 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@  {      "_meta": {          "hash": { -            "sha256": "644012a1c3fa3e3a30f8b8f8e672c468dfaa155d9e43d26e2be8713c8dc5ebb3" +            "sha256": "073fd0c51749aafa188fdbe96c5b90dd157cb1d23bdd144801fb0d0a369ffa88"          },          "pipfile-spec": 6,          "requires": { @@ -18,11 +18,11 @@      "default": {          "aio-pika": {              "hashes": [ -                "sha256:4a20d4d941e1f113a950ea529a90bd9159c8d7aafaa1c71e9c707c8c2b526ea6", -                "sha256:7bf3f183df1eb348d007210a0c1a3c5c755f1b3def1a9a395e93f30b91da1daf" +                "sha256:9773440a89840941ac3099a7720bf9d51e8764a484066b82ede4d395660ff430", +                "sha256:a8065be3c722eb8f9fff8c0e7590729e7782202cdb9363d9830d7d5d47b45c7c"              ],              "index": "pypi", -            "version": "==6.7.0" +            "version": "==6.7.1"          },          "aiodns": {              "hashes": [ @@ -205,22 +205,13 @@              "index": "pypi",              "version": "==4.3.2"          }, -        "discord": { -            "hashes": [ -                "sha256:9d4debb4a37845543bd4b92cb195bc53a302797333e768e70344222857ff1559", -                "sha256:ff6653655e342e7721dfb3f10421345fd852c2a33f2cca912b1c39b3778a9429" -            ], -            "index": "pypi", -            "py": "~=1.4.0", -            "version": "==1.0.1" -        },          "discord.py": {              "hashes": [ -                "sha256:98ea3096a3585c9c379209926f530808f5fcf4930928d8cfb579d2562d119570", -                "sha256:f9decb3bfa94613d922376288617e6a6f969260923643e2897f4540c34793442" +                "sha256:3acb61fde0d862ed346a191d69c46021e6063673f63963bc984ae09a685ab211", +                "sha256:e71089886aa157341644bdecad63a72ff56b44406b1a6467b66db31c8e5a5a15"              ], -            "markers": "python_full_version >= '3.5.3'", -            "version": "==1.4.1" +            "index": "pypi", +            "version": "==1.5.0"          },          "docutils": {              "hashes": [ @@ -1,6 +1,6 @@  # Python Utility Bot -[](https://discord.gg/2B963hn) +[](https://discord.gg/2B963hn)  [](https://dev.azure.com/python-discord/Python%20Discord/_build/latest?definitionId=1&branchName=master)  [](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)  [](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master) diff --git a/bot/__main__.py b/bot/__main__.py index 152ddbf92..367be1300 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -47,14 +47,22 @@ loop.run_until_complete(redis_session.connect())  # Instantiate the bot.  allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] +intents = discord.Intents().all() +intents.presences = False +intents.dm_typing = False +intents.dm_reactions = False +intents.invites = False +intents.webhooks = False +intents.integrations = False  bot = Bot(      redis_session=redis_session,      loop=loop,      command_prefix=when_mentioned_or(constants.Bot.prefix), -    activity=discord.Game(name="Commands: !help"), +    activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),      case_insensitive=True,      max_messages=10_000, -    allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles) +    allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), +    intents=intents,  )  # Load extensions. diff --git a/bot/constants.py b/bot/constants.py index c21fd52e0..23d5b4304 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -377,6 +377,7 @@ class Categories(metaclass=YAMLGetter):      help_in_use: int      help_dormant: int      modmail: int +    voice: int  class Channels(metaclass=YAMLGetter): @@ -392,6 +393,7 @@ class Channels(metaclass=YAMLGetter):      bot_commands: int      change_log: int      code_help_voice: int +    code_help_voice_2: int      cooldown: int      defcon: int      dev_contrib: int @@ -424,6 +426,8 @@ class Channels(metaclass=YAMLGetter):      user_event_announcements: int      user_log: int      verification: int +    voice_chat: int +    voice_gate: int      voice_log: int @@ -456,9 +460,11 @@ class Roles(metaclass=YAMLGetter):      owners: int      partners: int      python_community: int +    sprinters: int      team_leaders: int      unverified: int      verified: int  # This is the Developers role on PyDis, here named verified for readability reasons. +    voice_verified: int  class Guild(metaclass=YAMLGetter): @@ -467,6 +473,7 @@ class Guild(metaclass=YAMLGetter):      id: int      invite: str  # Discord invite, gets embedded in chat      moderation_channels: List[int] +    moderation_categories: List[int]      moderation_roles: List[int]      modlog_blacklist: List[int]      reminder_whitelist: List[int] @@ -528,6 +535,15 @@ class BigBrother(metaclass=YAMLGetter):      header_message_limit: int +class CodeBlock(metaclass=YAMLGetter): +    section = 'code_block' + +    channel_whitelist: List[int] +    cooldown_channels: List[int] +    cooldown_seconds: int +    minimum_lines: int + +  class Free(metaclass=YAMLGetter):      section = 'free' @@ -578,6 +594,14 @@ class Verification(metaclass=YAMLGetter):      kick_confirmation_threshold: float +class VoiceGate(metaclass=YAMLGetter): +    section = "voice_gate" + +    minimum_days_verified: int +    minimum_messages: int +    bot_message_delete_delay: int + +  class Event(Enum):      """      Event names. This does not include every event (for example, raw @@ -618,6 +642,9 @@ STAFF_ROLES = Guild.staff_roles  # Channel combinations  MODERATION_CHANNELS = Guild.moderation_channels +# Category combinations +MODERATION_CATEGORIES = Guild.moderation_categories +  # Bot replies  NEGATIVE_REPLIES = [      "Noooooo!!", diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 3d4a09df3..38468c2b1 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -14,7 +14,6 @@ log = logging.getLogger(__name__)  # These objects are declared as namedtuples because tuples are hashable,  # something that we make use of when diffing site roles against guild roles.  _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild'))  _Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) @@ -134,61 +133,76 @@ class UserSyncer(Syncer):      async def _get_diff(self, guild: Guild) -> _Diff:          """Return the difference of users between the cache of `guild` and the database."""          log.trace("Getting the diff for users.") -        users = await self.bot.api_client.get('bot/users') -        # Pack DB roles and guild roles into one common, hashable format. -        # They're hashable so that they're easily comparable with sets later. -        db_users = { -            user_dict['id']: _User( -                roles=tuple(sorted(user_dict.pop('roles'))), -                **user_dict -            ) -            for user_dict in users -        } -        guild_users = { -            member.id: _User( -                id=member.id, -                name=member.name, -                discriminator=int(member.discriminator), -                roles=tuple(sorted(role.id for role in member.roles)), -                in_guild=True -            ) -            for member in guild.members -        } +        users_to_create = [] +        users_to_update = [] +        seen_guild_users = set() + +        async for db_user in self._get_users(): +            # Store user fields which are to be updated. +            updated_fields = {} -        users_to_create = set() -        users_to_update = set() +            def maybe_update(db_field: str, guild_value: t.Union[str, int]) -> None: +                # Equalize DB user and guild user attributes. +                if db_user[db_field] != guild_value: +                    updated_fields[db_field] = guild_value -        for db_user in db_users.values(): -            guild_user = guild_users.get(db_user.id) -            if guild_user is not None: -                if db_user != guild_user: -                    users_to_update.add(guild_user) +            if guild_user := guild.get_member(db_user["id"]): +                seen_guild_users.add(guild_user.id) -            elif db_user.in_guild: +                maybe_update("name", guild_user.name) +                maybe_update("discriminator", int(guild_user.discriminator)) +                maybe_update("in_guild", True) + +                guild_roles = [role.id for role in guild_user.roles] +                if set(db_user["roles"]) != set(guild_roles): +                    updated_fields["roles"] = guild_roles + +            elif db_user["in_guild"]:                  # The user is known in the DB but not the guild, and the                  # DB currently specifies that the user is a member of the guild.                  # This means that the user has left since the last sync.                  # Update the `in_guild` attribute of the user on the site                  # to signify that the user left. -                new_api_user = db_user._replace(in_guild=False) -                users_to_update.add(new_api_user) - -        new_user_ids = set(guild_users.keys()) - set(db_users.keys()) -        for user_id in new_user_ids: -            # The user is known on the guild but not on the API. This means -            # that the user has joined since the last sync. Create it. -            new_user = guild_users[user_id] -            users_to_create.add(new_user) +                updated_fields["in_guild"] = False + +            if updated_fields: +                updated_fields["id"] = db_user["id"] +                users_to_update.append(updated_fields) + +        for member in guild.members: +            if member.id not in seen_guild_users: +                # The user is known on the guild but not on the API. This means +                # that the user has joined since the last sync. Create it. +                new_user = { +                    "id": member.id, +                    "name": member.name, +                    "discriminator": int(member.discriminator), +                    "roles": [role.id for role in member.roles], +                    "in_guild": True +                } +                users_to_create.append(new_user)          return _Diff(users_to_create, users_to_update, None) +    async def _get_users(self) -> t.AsyncIterable: +        """GET users from database.""" +        query_params = { +            "page": 1 +        } +        while query_params["page"]: +            res = await self.bot.api_client.get("bot/users", params=query_params) +            for user in res["results"]: +                yield user + +            query_params["page"] = res["next_page_no"] +      async def _sync(self, diff: _Diff) -> None:          """Synchronise the database with the user cache of `guild`."""          log.trace("Syncing created users...") -        for user in diff.created: -            await self.bot.api_client.post('bot/users', json=user._asdict()) +        if diff.created: +            await self.bot.api_client.post("bot/users", json=diff.created)          log.trace("Syncing updated users...") -        for user in diff.updated: -            await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) +        if diff.updated: +            await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated) diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py index 7894ec48f..26f00e91f 100644 --- a/bot/exts/filters/antimalware.py +++ b/bot/exts/filters/antimalware.py @@ -6,7 +6,7 @@ from discord import Embed, Message, NotFound  from discord.ext.commands import Cog  from bot.bot import Bot -from bot.constants import Channels, STAFF_ROLES, URLs +from bot.constants import Channels, Filter, URLs  log = logging.getLogger(__name__) @@ -61,7 +61,7 @@ class AntiMalware(Cog):          # Check if user is staff, if is, return          # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance -        if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): +        if hasattr(message.author, "roles") and any(role.id in Filter.role_whitelist for role in message.author.roles):              return          embed = Embed() diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 4964283f1..af8528a68 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -15,7 +15,6 @@ from bot.constants import (      AntiSpam as AntiSpamConfig, Channels,      Colours, DEBUG_MODE, Event, Filter,      Guild as GuildConfig, Icons, -    STAFF_ROLES,  )  from bot.converters import Duration  from bot.exts.moderation.modlog import ModLog @@ -149,7 +148,7 @@ class AntiSpam(Cog):              or message.guild.id != GuildConfig.id              or message.author.bot              or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) -            or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) +            or (any(role.id in Filter.role_whitelist for role in message.author.roles) and not DEBUG_MODE)          ):              return diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index 82084ea88..48aa2749c 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -22,6 +22,7 @@ class DuckPond(Cog):          self.bot = bot          self.webhook_id = constants.Webhooks.duck_pond          self.webhook = None +        self.ducked_messages = []          self.bot.loop.create_task(self.fetch_webhook())          self.relay_lock = None @@ -176,7 +177,8 @@ class DuckPond(Cog):          duck_count = await self.count_ducks(message)          # If we've got more than the required amount of ducks, send the message to the duck_pond. -        if duck_count >= constants.DuckPond.threshold: +        if duck_count >= constants.DuckPond.threshold and message.id not in self.ducked_messages: +            self.ducked_messages.append(message.id)              await self.locked_relay(message)      @Cog.listener() diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py index f5c9a5dd0..062d4fcfe 100644 --- a/bot/exts/help_channels.py +++ b/bot/exts/help_channels.py @@ -14,6 +14,7 @@ from discord.ext import commands  from bot import constants  from bot.bot import Bot +from bot.utils import channel as channel_utils  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) @@ -378,11 +379,18 @@ class HelpChannels(commands.Cog):          log.trace("Getting the CategoryChannel objects for the help categories.")          try: -            self.available_category = await self.try_get_channel( -                constants.Categories.help_available +            self.available_category = await channel_utils.try_get_channel( +                constants.Categories.help_available, +                self.bot +            ) +            self.in_use_category = await channel_utils.try_get_channel( +                constants.Categories.help_in_use, +                self.bot +            ) +            self.dormant_category = await channel_utils.try_get_channel( +                constants.Categories.help_dormant, +                self.bot              ) -            self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use) -            self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant)          except discord.HTTPException:              log.exception("Failed to get a category; cog will be removed")              self.bot.remove_cog(self.qualified_name) @@ -442,12 +450,6 @@ class HelpChannels(commands.Cog):              return False          return message.author == self.bot.user and bot_msg_desc.strip() == description.strip() -    @staticmethod -    def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: -        """Return True if `channel` is within a category with `category_id`.""" -        actual_category = getattr(channel, "category", None) -        return actual_category is not None and actual_category.id == category_id -      async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None:          """          Make the `channel` dormant if idle or schedule the move if still active. @@ -498,7 +500,7 @@ class HelpChannels(commands.Cog):          options should be avoided, as it may interfere with the category move we perform.          """          # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. -        category = await self.try_get_channel(category_id) +        category = await channel_utils.try_get_channel(category_id, self.bot)          payload = [{"id": c.id, "position": c.position} for c in category.channels] @@ -646,7 +648,7 @@ class HelpChannels(commands.Cog):          channel = message.channel          # Confirm the channel is an in use help channel -        if self.is_in_category(channel, constants.Categories.help_in_use): +        if channel_utils.is_in_category(channel, constants.Categories.help_in_use):              log.trace(f"Checking if #{channel} ({channel.id}) has been answered.")              # Check if there is an entry in unanswered @@ -671,7 +673,8 @@ class HelpChannels(commands.Cog):          await self.check_for_answer(message) -        if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel): +        is_available = channel_utils.is_in_category(channel, constants.Categories.help_available) +        if not is_available or self.is_excluded_channel(channel):              return  # Ignore messages outside the Available category or in excluded channels.          log.trace("Waiting for the cog to be ready before processing messages.") @@ -681,7 +684,7 @@ class HelpChannels(commands.Cog):          async with self.on_message_lock:              log.trace(f"on_message lock acquired for {message.id}.") -            if not self.is_in_category(channel, constants.Categories.help_available): +            if not channel_utils.is_in_category(channel, constants.Categories.help_available):                  log.debug(                      f"Message {message.id} will not make #{channel} ({channel.id}) in-use "                      f"because another message in the channel already triggered that." @@ -719,7 +722,7 @@ class HelpChannels(commands.Cog):          The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`.          """ -        if not self.is_in_category(msg.channel, constants.Categories.help_in_use): +        if not channel_utils.is_in_category(msg.channel, constants.Categories.help_in_use):              return          if not await self.is_empty(msg.channel): @@ -844,18 +847,6 @@ class HelpChannels(commands.Cog):              log.trace(f"Dormant message not found in {channel_info}; sending a new message.")              await channel.send(embed=embed) -    async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel: -        """Attempt to get or fetch a channel and return it.""" -        log.trace(f"Getting the channel {channel_id}.") - -        channel = self.bot.get_channel(channel_id) -        if not channel: -            log.debug(f"Channel {channel_id} is not in cache; fetching from API.") -            channel = await self.bot.fetch_channel(channel_id) - -        log.trace(f"Channel #{channel} ({channel_id}) retrieved.") -        return channel -      async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool:          """          Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. diff --git a/bot/exts/info/codeblock/__init__.py b/bot/exts/info/codeblock/__init__.py new file mode 100644 index 000000000..5c55bc5e3 --- /dev/null +++ b/bot/exts/info/codeblock/__init__.py @@ -0,0 +1,8 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: +    """Load the CodeBlockCog cog.""" +    # Defer import to reduce side effects from importing the codeblock package. +    from bot.exts.info.codeblock._cog import CodeBlockCog +    bot.add_cog(CodeBlockCog(bot)) diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py new file mode 100644 index 000000000..1e0feab0d --- /dev/null +++ b/bot/exts/info/codeblock/_cog.py @@ -0,0 +1,186 @@ +import logging +import time +from typing import Optional + +import discord +from discord import Message, RawMessageUpdateEvent +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.exts.filters.token_remover import TokenRemover +from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE +from bot.exts.info.codeblock._instructions import get_instructions +from bot.utils import has_lines +from bot.utils.channel import is_help_channel +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + + +class CodeBlockCog(Cog, name="Code Block"): +    """ +    Detect improperly formatted Markdown code blocks and suggest proper formatting. + +    There are four basic ways in which a code block is considered improperly formatted: + +    1. The code is not within a code block at all +        * Ignored if the code is not valid Python or Python REPL code +    2. Incorrect characters are used for backticks +    3. A language for syntax highlighting is not specified +        * Ignored if the code is not valid Python or Python REPL code +    4. A syntax highlighting language is incorrectly specified +        * Ignored if the language specified doesn't look like it was meant for Python +        * This can go wrong in two ways: +            1. Spaces before the language +            2. No newline immediately following the language + +    Messages or code blocks must meet a minimum line count to be detected. Detecting multiple code +    blocks is supported. However, if at least one code block is correct, then instructions will not +    be sent even if others are incorrect. When multiple incorrect code blocks are found, only the +    first one is used as the basis for the instructions sent. + +    When an issue is detected, an embed is sent containing specific instructions on fixing what +    is wrong. If the user edits their message to fix the code block, the instructions will be +    removed. If they fail to fix the code block with an edit, the instructions will be updated to +    show what is still incorrect after the user's edit. The embed can be manually deleted with a +    reaction. Otherwise, it will automatically be removed after 5 minutes. + +    The cog only detects messages in whitelisted channels. Channels may also have a cooldown on the +    instructions being sent. Note all help channels are also whitelisted with cooldowns enabled. + +    For configurable parameters, see the `code_block` section in config-default.py. +    """ + +    def __init__(self, bot: Bot): +        self.bot = bot + +        # Stores allowed channels plus epoch times since the last instructional messages sent. +        self.channel_cooldowns = dict.fromkeys(constants.CodeBlock.cooldown_channels, 0.0) + +        # Maps users' messages to the messages the bot sent with instructions. +        self.codeblock_message_ids = {} + +    @staticmethod +    def create_embed(instructions: str) -> discord.Embed: +        """Return an embed which displays code block formatting `instructions`.""" +        return discord.Embed(description=instructions) + +    async def get_sent_instructions(self, payload: RawMessageUpdateEvent) -> Optional[Message]: +        """ +        Return the bot's sent instructions message associated with a user's message `payload`. + +        Return None if the message cannot be found. In this case, it's likely the message was +        deleted either manually via a reaction or automatically by a timer. +        """ +        log.trace(f"Retrieving instructions message for ID {payload.message_id}") +        channel = self.bot.get_channel(payload.channel_id) + +        try: +            return await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) +        except discord.NotFound: +            log.debug("Could not find instructions message; it was probably deleted.") +            return None + +    def is_on_cooldown(self, channel: discord.TextChannel) -> bool: +        """ +        Return True if an embed was sent too recently for `channel`. + +        The cooldown is configured by `constants.CodeBlock.cooldown_seconds`. +        Note: only channels in the `channel_cooldowns` have cooldowns enabled. +        """ +        log.trace(f"Checking if #{channel} is on cooldown.") +        cooldown = constants.CodeBlock.cooldown_seconds +        return (time.time() - self.channel_cooldowns.get(channel.id, 0)) < cooldown + +    def is_valid_channel(self, channel: discord.TextChannel) -> bool: +        """Return True if `channel` is a help channel, may be on a cooldown, or is whitelisted.""" +        log.trace(f"Checking if #{channel} qualifies for code block detection.") +        return ( +            is_help_channel(channel) +            or channel.id in self.channel_cooldowns +            or channel.id in constants.CodeBlock.channel_whitelist +        ) + +    async def send_instructions(self, message: discord.Message, instructions: str) -> None: +        """ +        Send an embed with `instructions` on fixing an incorrect code block in a `message`. + +        The embed will be deleted automatically after 5 minutes. +        """ +        log.info(f"Sending code block formatting instructions for message {message.id}.") + +        embed = self.create_embed(instructions) +        bot_message = await message.channel.send(f"Hey {message.author.mention}!", embed=embed) +        self.codeblock_message_ids[message.id] = bot_message.id + +        self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,), self.bot)) + +        # Increase amount of codeblock correction in stats +        self.bot.stats.incr("codeblock_corrections") + +    def should_parse(self, message: discord.Message) -> bool: +        """ +        Return True if `message` should be parsed. + +        A qualifying message: + +        1. Is not authored by a bot +        2. Is in a valid channel +        3. Has more than 3 lines +        4. Has no bot or webhook token +        """ +        return ( +            not message.author.bot +            and self.is_valid_channel(message.channel) +            and has_lines(message.content, constants.CodeBlock.minimum_lines) +            and not TokenRemover.find_token_in_message(message) +            and not WEBHOOK_URL_RE.search(message.content) +        ) + +    @Cog.listener() +    async def on_message(self, msg: Message) -> None: +        """Detect incorrect Markdown code blocks in `msg` and send instructions to fix them.""" +        if not self.should_parse(msg): +            log.trace(f"Skipping code block detection of {msg.id}: message doesn't qualify.") +            return + +        # When debugging, ignore cooldowns. +        if self.is_on_cooldown(msg.channel) and not constants.DEBUG_MODE: +            log.trace(f"Skipping code block detection of {msg.id}: #{msg.channel} is on cooldown.") +            return + +        instructions = get_instructions(msg.content) +        if instructions: +            await self.send_instructions(msg, instructions) + +            if msg.channel.id not in constants.CodeBlock.channel_whitelist: +                log.debug(f"Adding #{msg.channel} to the channel cooldowns.") +                self.channel_cooldowns[msg.channel.id] = time.time() + +    @Cog.listener() +    async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: +        """Delete the instructional message if an edited message had its code blocks fixed.""" +        if payload.message_id not in self.codeblock_message_ids: +            log.trace(f"Ignoring message edit {payload.message_id}: message isn't being tracked.") +            return + +        if payload.data.get("content") is None or payload.data.get("channel_id") is None: +            log.trace(f"Ignoring message edit {payload.message_id}: missing content or channel ID.") +            return + +        # Parse the message to see if the code blocks have been fixed. +        content = payload.data.get("content") +        instructions = get_instructions(content) + +        bot_message = await self.get_sent_instructions(payload) +        if not bot_message: +            return + +        if not instructions: +            log.info("User's incorrect code block has been fixed. Removing instructions message.") +            await bot_message.delete() +            del self.codeblock_message_ids[payload.message_id] +        else: +            log.info("Message edited but still has invalid code blocks; editing the instructions.") +            await bot_message.edit(embed=self.create_embed(instructions)) diff --git a/bot/exts/info/codeblock/_instructions.py b/bot/exts/info/codeblock/_instructions.py new file mode 100644 index 000000000..508f157fb --- /dev/null +++ b/bot/exts/info/codeblock/_instructions.py @@ -0,0 +1,184 @@ +"""This module generates and formats instructional messages about fixing Markdown code blocks.""" + +import logging +from typing import Optional + +from bot.exts.info.codeblock import _parsing + +log = logging.getLogger(__name__) + +_EXAMPLE_PY = "{lang}\nprint('Hello, world!')"  # Make sure to escape any Markdown symbols here. +_EXAMPLE_CODE_BLOCKS = ( +    "\\`\\`\\`{content}\n\\`\\`\\`\n\n" +    "**This will result in the following:**\n" +    "```{content}```" +) + + +def _get_example(language: str) -> str: +    """Return an example of a correct code block using `language` for syntax highlighting.""" +    # Determine the example code to put in the code block based on the language specifier. +    if language.lower() in _parsing.PY_LANG_CODES: +        log.trace(f"Code block has a Python language specifier `{language}`.") +        content = _EXAMPLE_PY.format(lang=language) +    elif language: +        log.trace(f"Code block has a foreign language specifier `{language}`.") +        # It's not feasible to determine what would be a valid example for other languages. +        content = f"{language}\n..." +    else: +        log.trace("Code block has no language specifier.") +        content = "\nHello, world!" + +    return _EXAMPLE_CODE_BLOCKS.format(content=content) + + +def _get_bad_ticks_message(code_block: _parsing.CodeBlock) -> Optional[str]: +    """Return instructions on using the correct ticks for `code_block`.""" +    log.trace("Creating instructions for incorrect code block ticks.") + +    valid_ticks = f"\\{_parsing.BACKTICK}" * 3 +    instructions = ( +        "It looks like you are trying to paste code into this channel.\n\n" +        "You seem to be using the wrong symbols to indicate where the code block should start. " +        f"The correct symbols would be {valid_ticks}, not `{code_block.tick * 3}`." +    ) + +    log.trace("Check if the bad ticks code block also has issues with the language specifier.") +    addition_msg = _get_bad_lang_message(code_block.content) +    if not addition_msg and not code_block.language: +        addition_msg = _get_no_lang_message(code_block.content) + +    # Combine the back ticks message with the language specifier message. The latter will +    # already have an example code block. +    if addition_msg: +        log.trace("Language specifier issue found; appending additional instructions.") + +        # The first line has double newlines which are not desirable when appending the msg. +        addition_msg = addition_msg.replace("\n\n", " ", 1) + +        # Make the first character of the addition lower case. +        instructions += "\n\nFurthermore, " + addition_msg[0].lower() + addition_msg[1:] +    else: +        log.trace("No issues with the language specifier found.") +        example_blocks = _get_example(code_block.language) +        instructions += f"\n\n**Here is an example of how it should look:**\n{example_blocks}" + +    return instructions + + +def _get_no_ticks_message(content: str) -> Optional[str]: +    """If `content` is Python/REPL code, return instructions on using code blocks.""" +    log.trace("Creating instructions for a missing code block.") + +    if _parsing.is_python_code(content): +        example_blocks = _get_example("python") +        return ( +            "It looks like you're trying to paste code into this channel.\n\n" +            "Discord has support for Markdown, which allows you to post code with full " +            "syntax highlighting. Please use these whenever you paste code, as this " +            "helps improve the legibility and makes it easier for us to help you.\n\n" +            f"**To do this, use the following method:**\n{example_blocks}" +        ) +    else: +        log.trace("Aborting missing code block instructions: content is not Python code.") + + +def _get_bad_lang_message(content: str) -> Optional[str]: +    """ +    Return instructions on fixing the Python language specifier for a code block. + +    If `code_block` does not have a Python language specifier, return None. +    If there's nothing wrong with the language specifier, return None. +    """ +    log.trace("Creating instructions for a poorly specified language.") + +    info = _parsing.parse_bad_language(content) +    if not info: +        log.trace("Aborting bad language instructions: language specified isn't Python.") +        return + +    lines = [] +    language = info.language + +    if info.has_leading_spaces: +        log.trace("Language specifier was preceded by a space.") +        lines.append(f"Make sure there are no spaces between the back ticks and `{language}`.") + +    if not info.has_terminal_newline: +        log.trace("Language specifier was not followed by a newline.") +        lines.append( +            f"Make sure you put your code on a new line following `{language}`. " +            f"There must not be any spaces after `{language}`." +        ) + +    if lines: +        lines = " ".join(lines) +        example_blocks = _get_example(language) + +        # Note that _get_bad_ticks_message expects the first line to have two newlines. +        return ( +            f"It looks like you incorrectly specified a language for your code block.\n\n{lines}" +            f"\n\n**Here is an example of how it should look:**\n{example_blocks}" +        ) +    else: +        log.trace("Nothing wrong with the language specifier; no instructions to return.") + + +def _get_no_lang_message(content: str) -> Optional[str]: +    """ +    Return instructions on specifying a language for a code block. + +    If `content` is not valid Python or Python REPL code, return None. +    """ +    log.trace("Creating instructions for a missing language.") + +    if _parsing.is_python_code(content): +        example_blocks = _get_example("python") + +        # Note that _get_bad_ticks_message expects the first line to have two newlines. +        return ( +            "It looks like you pasted Python code without syntax highlighting.\n\n" +            "Please use syntax highlighting to improve the legibility of your code and make " +            "it easier for us to help you.\n\n" +            f"**To do this, use the following method:**\n{example_blocks}" +        ) +    else: +        log.trace("Aborting missing language instructions: content is not Python code.") + + +def get_instructions(content: str) -> Optional[str]: +    """ +    Parse `content` and return code block formatting instructions if something is wrong. + +    Return None if `content` lacks code block formatting issues. +    """ +    log.trace("Getting formatting instructions.") + +    blocks = _parsing.find_code_blocks(content) +    if blocks is None: +        log.trace("At least one valid code block found; no instructions to return.") +        return + +    if not blocks: +        log.trace("No code blocks were found in message.") +        instructions = _get_no_ticks_message(content) +    else: +        log.trace("Searching results for a code block with invalid ticks.") +        block = next((block for block in blocks if block.tick != _parsing.BACKTICK), None) + +        if block: +            log.trace("A code block exists but has invalid ticks.") +            instructions = _get_bad_ticks_message(block) +        else: +            log.trace("A code block exists but is missing a language.") +            block = blocks[0] + +            # Check for a bad language first to avoid parsing content into an AST. +            instructions = _get_bad_lang_message(block.content) +            if not instructions: +                instructions = _get_no_lang_message(block.content) + +    if instructions: +        instructions += "\nYou can **edit your original message** to correct your code block." + +    return instructions diff --git a/bot/exts/info/codeblock/_parsing.py b/bot/exts/info/codeblock/_parsing.py new file mode 100644 index 000000000..a98218dfb --- /dev/null +++ b/bot/exts/info/codeblock/_parsing.py @@ -0,0 +1,228 @@ +"""This module provides functions for parsing Markdown code blocks.""" + +import ast +import logging +import re +import textwrap +from typing import NamedTuple, Optional, Sequence + +from bot import constants +from bot.utils import has_lines + +log = logging.getLogger(__name__) + +BACKTICK = "`" +PY_LANG_CODES = ("python", "pycon", "py")  # Order is important; "py" is last cause it's a subset. +_TICKS = { +    BACKTICK, +    "'", +    '"', +    "\u00b4",  # ACUTE ACCENT +    "\u2018",  # LEFT SINGLE QUOTATION MARK +    "\u2019",  # RIGHT SINGLE QUOTATION MARK +    "\u2032",  # PRIME +    "\u201c",  # LEFT DOUBLE QUOTATION MARK +    "\u201d",  # RIGHT DOUBLE QUOTATION MARK +    "\u2033",  # DOUBLE PRIME +    "\u3003",  # VERTICAL KANA REPEAT MARK UPPER HALF +} + +_RE_PYTHON_REPL = re.compile(r"^(>>>|\.\.\.)( |$)") +_RE_IPYTHON_REPL = re.compile(r"^((In|Out) \[\d+\]: |\s*\.{3,}: ?)") + +_RE_CODE_BLOCK = re.compile( +    fr""" +    (?P<ticks> +        (?P<tick>[{''.join(_TICKS)}]) # Put all ticks into a character class within a group. +        \2{{2}}                       # Match previous group 2 more times to ensure the same char. +    ) +    (?P<lang>[^\W_]+\n)?              # Optionally match a language specifier followed by a newline. +    (?P<code>.+?)                     # Match the actual code within the block. +    \1                                # Match the same 3 ticks used at the start of the block. +    """, +    re.DOTALL | re.VERBOSE +) + +_RE_LANGUAGE = re.compile( +    fr""" +    ^(?P<spaces>\s+)?                    # Optionally match leading spaces from the beginning. +    (?P<lang>{'|'.join(PY_LANG_CODES)})  # Match a Python language. +    (?P<newline>\n)?                     # Optionally match a newline following the language. +    """, +    re.IGNORECASE | re.VERBOSE +) + + +class CodeBlock(NamedTuple): +    """Represents a Markdown code block.""" + +    content: str +    language: str +    tick: str + + +class BadLanguage(NamedTuple): +    """Parsed information about a poorly formatted language specifier.""" + +    language: str +    has_leading_spaces: bool +    has_terminal_newline: bool + + +def find_code_blocks(message: str) -> Optional[Sequence[CodeBlock]]: +    """ +    Find and return all Markdown code blocks in the `message`. + +    Code blocks with 3 or fewer lines are excluded. + +    If the `message` contains at least one code block with valid ticks and a specified language, +    return None. This is based on the assumption that if the user managed to get one code block +    right, they already know how to fix the rest themselves. +    """ +    log.trace("Finding all code blocks in a message.") + +    code_blocks = [] +    for match in _RE_CODE_BLOCK.finditer(message): +        # Used to ensure non-matched groups have an empty string as the default value. +        groups = match.groupdict("") +        language = groups["lang"].strip()  # Strip the newline cause it's included in the group. + +        if groups["tick"] == BACKTICK and language: +            log.trace("Message has a valid code block with a language; returning None.") +            return None +        elif has_lines(groups["code"], constants.CodeBlock.minimum_lines): +            code_block = CodeBlock(groups["code"], language, groups["tick"]) +            code_blocks.append(code_block) +        else: +            log.trace("Skipped a code block shorter than 4 lines.") + +    return code_blocks + + +def _is_python_code(content: str) -> bool: +    """Return True if `content` is valid Python consisting of more than just expressions.""" +    log.trace("Checking if content is Python code.") +    try: +        # Attempt to parse the message into an AST node. +        # Invalid Python code will raise a SyntaxError. +        tree = ast.parse(content) +    except SyntaxError: +        log.trace("Code is not valid Python.") +        return False + +    # Multiple lines of single words could be interpreted as expressions. +    # This check is to avoid all nodes being parsed as expressions. +    # (e.g. words over multiple lines) +    if not all(isinstance(node, ast.Expr) for node in tree.body): +        log.trace("Code is valid python.") +        return True +    else: +        log.trace("Code consists only of expressions.") +        return False + + +def _is_repl_code(content: str, threshold: int = 3) -> bool: +    """Return True if `content` has at least `threshold` number of (I)Python REPL-like lines.""" +    log.trace(f"Checking if content is (I)Python REPL code using a threshold of {threshold}.") + +    repl_lines = 0 +    patterns = (_RE_PYTHON_REPL, _RE_IPYTHON_REPL) + +    for line in content.splitlines(): +        # Check the line against all patterns. +        for pattern in patterns: +            if pattern.match(line): +                repl_lines += 1 + +                # Once a pattern is matched, only use that pattern for the remaining lines. +                patterns = (pattern,) +                break + +        if repl_lines == threshold: +            log.trace("Content is (I)Python REPL code.") +            return True + +    log.trace("Content is not (I)Python REPL code.") +    return False + + +def is_python_code(content: str) -> bool: +    """Return True if `content` is valid Python code or (I)Python REPL output.""" +    dedented = textwrap.dedent(content) + +    # Parse AST twice in case _fix_indentation ends up breaking code due to its inaccuracies. +    return ( +        _is_python_code(dedented) +        or _is_repl_code(dedented) +        or _is_python_code(_fix_indentation(content)) +    ) + + +def parse_bad_language(content: str) -> Optional[BadLanguage]: +    """ +    Return information about a poorly formatted Python language in code block `content`. + +    If the language is not Python, return None. +    """ +    log.trace("Parsing bad language.") + +    match = _RE_LANGUAGE.match(content) +    if not match: +        return None + +    return BadLanguage( +        language=match["lang"], +        has_leading_spaces=match["spaces"] is not None, +        has_terminal_newline=match["newline"] is not None, +    ) + + +def _get_leading_spaces(content: str) -> int: +    """Return the number of spaces at the start of the first line in `content`.""" +    leading_spaces = 0 +    for char in content: +        if char == " ": +            leading_spaces += 1 +        else: +            return leading_spaces + + +def _fix_indentation(content: str) -> str: +    """ +    Attempt to fix badly indented code in `content`. + +    In most cases, this works like textwrap.dedent. However, if the first line ends with a colon, +    all subsequent lines are re-indented to only be one level deep relative to the first line. +    The intent is to fix cases where the leading spaces of the first line of code were accidentally +    not copied, which makes the first line appear not indented. + +    This is fairly naïve and inaccurate. Therefore, it may break some code that was otherwise valid. +    It's meant to catch really common cases, so that's acceptable. Its flaws are: + +    - It assumes that if the first line ends with a colon, it is the start of an indented block +    - It uses 4 spaces as the indentation, regardless of what the rest of the code uses +    """ +    lines = content.splitlines(keepends=True) + +    # Dedent the first line +    first_indent = _get_leading_spaces(content) +    first_line = lines[0][first_indent:] + +    # Can't assume there'll be multiple lines cause line counts of edited messages aren't checked. +    if len(lines) == 1: +        return first_line + +    second_indent = _get_leading_spaces(lines[1]) + +    # If the first line ends with a colon, all successive lines need to be indented one +    # additional level (assumes an indent width of 4). +    if first_line.rstrip().endswith(":"): +        second_indent -= 4 + +    # All lines must be dedented at least by the same amount as the first line. +    first_indent = max(first_indent, second_indent) + +    # Dedent the rest of the lines and join them together with the first line. +    content = first_line + "".join(line[first_indent:] for line in lines[1:]) + +    return content diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 52239c19e..5aaf85e5a 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,15 +6,16 @@ from collections import Counter, defaultdict  from string import Template  from typing import Any, Mapping, Optional, Tuple, Union -from discord import ChannelType, Colour, CustomActivity, Embed, Guild, Member, Message, Role, Status, utils +from discord import ChannelType, Colour, Embed, Guild, Message, Role, Status, utils  from discord.abc import GuildChannel  from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role -from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot +from bot.converters import FetchedMember  from bot.decorators import in_whitelist  from bot.pagination import LinePaginator +from bot.utils.channel import is_mod_channel  from bot.utils.checks import cooldown_with_role_bypass, has_no_roles_check, in_whitelist_check  from bot.utils.time import time_since @@ -153,7 +154,9 @@ class Information(Cog):          channel_counts = self.get_channel_type_counts(ctx.guild)          # How many of each user status? -        statuses = Counter(member.status for member in ctx.guild.members) +        py_invite = await self.bot.fetch_invite(constants.Guild.invite) +        online_presences = py_invite.approximate_presence_count +        offline_presences = py_invite.approximate_member_count - online_presences          embed = Embed(colour=Colour.blurple())          # How many staff members and staff channels do we have? @@ -181,10 +184,8 @@ class Information(Cog):                  Roles: {roles}                  **Member statuses** -                {constants.Emojis.status_online} {statuses[Status.online]:,} -                {constants.Emojis.status_idle} {statuses[Status.idle]:,} -                {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} -                {constants.Emojis.status_offline} {statuses[Status.offline]:,} +                {constants.Emojis.status_online} {online_presences:,} +                {constants.Emojis.status_offline} {offline_presences:,}              """)          ).substitute({"channel_counts": channel_counts})          embed.set_thumbnail(url=ctx.guild.icon_url) @@ -192,7 +193,7 @@ class Information(Cog):          await ctx.send(embed=embed)      @command(name="user", aliases=["user_info", "member", "member_info"]) -    async def user_info(self, ctx: Context, user: Member = None) -> None: +    async def user_info(self, ctx: Context, user: FetchedMember = None) -> None:          """Returns info about a user."""          if user is None:              user = ctx.author @@ -207,31 +208,14 @@ class Information(Cog):              embed = await self.create_user_embed(ctx, user)              await ctx.send(embed=embed) -    async def create_user_embed(self, ctx: Context, user: Member) -> Embed: +    async def create_user_embed(self, ctx: Context, user: FetchedMember) -> Embed:          """Creates an embed containing information on the `user`.""" -        created = time_since(user.created_at, max_units=3) - -        # Custom status -        custom_status = '' -        for activity in user.activities: -            if isinstance(activity, CustomActivity): -                state = "" - -                if activity.name: -                    state = escape_markdown(activity.name) - -                emoji = "" -                if activity.emoji: -                    # If an emoji is unicode use the emoji, else write the emote like :abc: -                    if not activity.emoji.id: -                        emoji += activity.emoji.name + " " -                    else: -                        emoji += f"`:{activity.emoji.name}:` " +        on_server = bool(ctx.guild.get_member(user.id)) -                custom_status = f'Status: {emoji}{state}\n' +        created = time_since(user.created_at, max_units=3)          name = str(user) -        if user.nick: +        if on_server and user.nick:              name = f"{user.nick} ({name})"          badges = [] @@ -240,12 +224,16 @@ class Information(Cog):              if is_set and (emoji := getattr(constants.Emojis, f"badge_{badge}", None)):                  badges.append(emoji) -        joined = time_since(user.joined_at, max_units=3) -        roles = ", ".join(role.mention for role in user.roles[1:]) - -        desktop_status = STATUS_EMOTES.get(user.desktop_status, constants.Emojis.status_online) -        web_status = STATUS_EMOTES.get(user.web_status, constants.Emojis.status_online) -        mobile_status = STATUS_EMOTES.get(user.mobile_status, constants.Emojis.status_online) +        if on_server: +            joined = time_since(user.joined_at, max_units=3) +            roles = ", ".join(role.mention for role in user.roles[1:]) +            membership = textwrap.dedent(f""" +                             Joined: {joined} +                             Roles: {roles or None} +                         """).strip() +        else: +            roles = None +            membership = "The user is not a member of the server"          fields = [              ( @@ -254,34 +242,16 @@ class Information(Cog):                      Created: {created}                      Profile: {user.mention}                      ID: {user.id} -                    {custom_status}                  """).strip()              ),              (                  "Member information", -                textwrap.dedent(f""" -                    Joined: {joined} -                    Roles: {roles or None} -                """).strip() +                membership              ), -            ( -                "Status", -                textwrap.dedent(f""" -                    {desktop_status} Desktop -                    {web_status} Web -                    {mobile_status} Mobile -                """).strip() -            )          ] -        # Use getattr to future-proof for commands invoked via DMs. -        show_verbose = ( -            ctx.channel.id in constants.MODERATION_CHANNELS -            or getattr(ctx.channel, "category_id", None) == constants.Categories.modmail -        ) -          # Show more verbose output in moderation channels for infractions and nominations -        if show_verbose: +        if is_mod_channel(ctx.channel):              fields.append(await self.expanded_user_infraction_counts(user))              fields.append(await self.user_nomination_counts(user))          else: @@ -301,13 +271,13 @@ class Information(Cog):          return embed -    async def basic_user_infraction_counts(self, member: Member) -> Tuple[str, str]: +    async def basic_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:          """Gets the total and active infraction counts for the given `member`."""          infractions = await self.bot.api_client.get(              'bot/infractions',              params={                  'hidden': 'False', -                'user__id': str(member.id) +                'user__id': str(user.id)              }          ) @@ -318,7 +288,7 @@ class Information(Cog):          return "Infractions", infraction_output -    async def expanded_user_infraction_counts(self, member: Member) -> Tuple[str, str]: +    async def expanded_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:          """          Gets expanded infraction counts for the given `member`. @@ -328,7 +298,7 @@ class Information(Cog):          infractions = await self.bot.api_client.get(              'bot/infractions',              params={ -                'user__id': str(member.id) +                'user__id': str(user.id)              }          ) @@ -359,12 +329,12 @@ class Information(Cog):          return "Infractions", "\n".join(infraction_output) -    async def user_nomination_counts(self, member: Member) -> Tuple[str, str]: +    async def user_nomination_counts(self, user: FetchedMember) -> Tuple[str, str]:          """Gets the active and historical nomination counts for the given `member`."""          nominations = await self.bot.api_client.get(              'bot/nominations',              params={ -                'user__id': str(member.id) +                'user__id': str(user.id)              }          ) diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py index debe40c82..bad4c504d 100644 --- a/bot/exts/info/reddit.py +++ b/bot/exts/info/reddit.py @@ -140,7 +140,10 @@ class Reddit(Cog):                  # Got appropriate response - process and return.                  content = await response.json()                  posts = content["data"]["children"] -                return posts[:amount] + +                filtered_posts = [post for post in posts if not post["data"]["over_18"]] + +                return filtered_posts[:amount]              await asyncio.sleep(3) @@ -163,12 +166,11 @@ class Reddit(Cog):              amount=amount,              params={"t": time}          ) -          if not posts:              embed.title = random.choice(ERROR_REPLIES)              embed.colour = Colour.red()              embed.description = ( -                "Sorry! We couldn't find any posts from that subreddit. " +                "Sorry! We couldn't find any SFW posts from that subreddit. "                  "If this problem persists, please let us know."              ) diff --git a/bot/exts/info/site.py b/bot/exts/info/site.py index 2d3a3d9f3..fb5b99086 100644 --- a/bot/exts/info/site.py +++ b/bot/exts/info/site.py @@ -1,7 +1,7 @@  import logging  from discord import Colour, Embed -from discord.ext.commands import Cog, Context, group +from discord.ext.commands import Cog, Context, Greedy, group  from bot.bot import Bot  from bot.constants import URLs @@ -105,10 +105,9 @@ class Site(Cog):          await ctx.send(embed=embed)      @site_group.command(name="rules", aliases=("r", "rule"), root_aliases=("rules", "rule")) -    async def site_rules(self, ctx: Context, *rules: int) -> None: +    async def site_rules(self, ctx: Context, rules: Greedy[int]) -> None:          """Provides a link to all rules or, if specified, displays specific rule(s).""" -        rules_embed = Embed(title='Rules', color=Colour.blurple()) -        rules_embed.url = f"{PAGES_URL}/rules" +        rules_embed = Embed(title='Rules', color=Colour.blurple(), url=f'{PAGES_URL}/rules')          if not rules:              # Rules were not submitted. Return the default description. @@ -122,15 +121,13 @@ class Site(Cog):              return          full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'}) -        invalid_indices = tuple( -            pick -            for pick in rules -            if pick < 1 or pick > len(full_rules) -        ) -        if invalid_indices: -            indices = ', '.join(map(str, invalid_indices)) -            await ctx.send(f":x: Invalid rule indices: {indices}") +        # Remove duplicates and sort the rule indices +        rules = sorted(set(rules)) +        invalid = ', '.join(str(index) for index in rules if index < 1 or index > len(full_rules)) + +        if invalid: +            await ctx.send(f":x: Invalid rule indices: {invalid}")              return          for rule in rules: diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py index d42f55466..4d8bb645e 100644 --- a/bot/exts/info/stats.py +++ b/bot/exts/info/stats.py @@ -1,13 +1,12 @@  import string -from datetime import datetime -from discord import Member, Message, Status +from discord import Member, Message  from discord.ext.commands import Cog, Context  from discord.ext.tasks import loop  from bot.bot import Bot -from bot.constants import Categories, Channels, Guild, Stats as StatConf - +from bot.constants import Categories, Channels, Guild +from bot.utils.channel import is_in_category  CHANNEL_NAME_OVERRIDES = {      Channels.off_topic_0: "off_topic_0", @@ -36,8 +35,7 @@ class Stats(Cog):          if message.guild.id != Guild.id:              return -        cat = getattr(message.channel, "category", None) -        if cat is not None and cat.id == Categories.modmail: +        if is_in_category(message.channel, Categories.modmail):              if message.channel.id != Channels.incidents:                  # Do not report modmail channels to stats, there are too many                  # of them for interesting statistics to be drawn out of this. @@ -79,38 +77,6 @@ class Stats(Cog):          self.bot.stats.gauge("guild.total_members", len(member.guild.members)) -    @Cog.listener() -    async def on_member_update(self, _before: Member, after: Member) -> None: -        """Update presence estimates on member update.""" -        if after.guild.id != Guild.id: -            return - -        if self.last_presence_update: -            if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout: -                return - -        self.last_presence_update = datetime.now() - -        online = 0 -        idle = 0 -        dnd = 0 -        offline = 0 - -        for member in after.guild.members: -            if member.status is Status.online: -                online += 1 -            elif member.status is Status.dnd: -                dnd += 1 -            elif member.status is Status.idle: -                idle += 1 -            elif member.status is Status.offline: -                offline += 1 - -        self.bot.stats.gauge("guild.status.online", online) -        self.bot.stats.gauge("guild.status.idle", idle) -        self.bot.stats.gauge("guild.status.do_not_disturb", dnd) -        self.bot.stats.gauge("guild.status.offline", offline) -      @loop(hours=1)      async def update_guild_boost(self) -> None:          """Post the server boost level and tier every hour.""" diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py index 14263e004..4d5142b55 100644 --- a/bot/exts/moderation/dm_relay.py +++ b/bot/exts/moderation/dm_relay.py @@ -90,7 +90,11 @@ class DMRelay(Cog):          # Handle any attachments          if message.attachments:              try: -                await send_attachments(message, self.webhook) +                await send_attachments( +                    message, +                    self.webhook, +                    username=f"{message.author.display_name} ({message.author.id})" +                )              except (discord.errors.Forbidden, discord.errors.NotFound):                  e = discord.Embed(                      description=":x: **This message contained an attachment, but it could not be retrieved**", diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 814b17830..bebade0ae 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -12,11 +12,12 @@ from discord.ext.commands import Context  from bot import constants  from bot.api import ResponseCodeError  from bot.bot import Bot -from bot.constants import Colours, MODERATION_CHANNELS +from bot.constants import Colours  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._utils import UserSnowflake  from bot.exts.moderation.modlog import ModLog  from bot.utils import messages, scheduling, time +from bot.utils.channel import is_mod_channel  log = logging.getLogger(__name__) @@ -125,7 +126,7 @@ class InfractionScheduler:                  log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})")              else:                  # Accordingly display whether the user was successfully notified via DM. -                if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): +                if await _utils.notify_infraction(user, " ".join(infr_type.split("_")).title(), expiry, reason, icon):                      dm_result = ":incoming_envelope: "                      dm_log_text = "\nDM: Sent" @@ -136,11 +137,7 @@ class InfractionScheduler:              )              if reason:                  end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" -        elif ctx.channel.id not in MODERATION_CHANNELS: -            log.trace( -                f"Infraction #{id_} context is not in a mod channel; omitting infraction count." -            ) -        else: +        elif is_mod_channel(ctx.channel):              log.trace(f"Fetching total infraction count for {user}.")              infractions = await self.bot.api_client.get( @@ -148,7 +145,7 @@ class InfractionScheduler:                  params={"user__id": str(user.id)}              )              total = len(infractions) -            end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" +            end_msg = f" (#{id_} ; {total} infraction{ngettext('', 's', total)} total)"          # Execute the necessary actions to apply the infraction on Discord.          if action_coro: @@ -166,7 +163,7 @@ class InfractionScheduler:                  log_content = ctx.author.mention                  log_title = "failed to apply" -                log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" +                log_msg = f"Failed to apply {' '.join(infr_type.split('_'))} infraction #{id_} to {user}"                  if isinstance(e, discord.Forbidden):                      log.warning(f"{log_msg}: bot lacks permissions.")                  else: @@ -183,7 +180,7 @@ class InfractionScheduler:                  log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.")              infr_message = ""          else: -            infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" +            infr_message = f" **{' '.join(infr_type.split('_'))}** to {user.mention}{expiry_msg}{end_msg}"          # Send a confirmation message to the invoking context.          log.trace(f"Sending infraction #{id_} confirmation message.") @@ -195,7 +192,7 @@ class InfractionScheduler:          await self.mod_log.send_log_message(              icon_url=icon,              colour=Colours.soft_red, -            title=f"Infraction {log_title}: {infr_type}", +            title=f"Infraction {log_title}: {' '.join(infr_type.split('_'))}",              thumbnail=user.avatar_url_as(static_format="png"),              text=textwrap.dedent(f"""                  Member: {messages.format_user(user)} @@ -272,7 +269,7 @@ class InfractionScheduler:          if send_msg:              log.trace(f"Sending infraction #{id_} pardon confirmation message.")              await ctx.send( -                f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " +                f"{dm_emoji}{confirm_msg} infraction **{' '.join(infr_type.split('_'))}** for {user.mention}. "                  f"{log_text.get('Failure', '')}"              ) @@ -283,7 +280,7 @@ class InfractionScheduler:          await self.mod_log.send_log_message(              icon_url=_utils.INFRACTION_ICONS[infr_type][1],              colour=Colours.soft_green, -            title=f"Infraction {log_title}: {infr_type}", +            title=f"Infraction {log_title}: {' '.join(infr_type.split('_'))}",              thumbnail=user.avatar_url_as(static_format="png"),              text="\n".join(f"{k}: {v}" for k, v in log_text.items()),              footer=footer, diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 1d91964f1..d0dc3f0a1 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -18,9 +18,10 @@ INFRACTION_ICONS = {      "note": (Icons.user_warn, None),      "superstar": (Icons.superstarify, Icons.unsuperstarify),      "warning": (Icons.user_warn, None), +    "voice_ban": (Icons.voice_state_red, Icons.voice_state_green),  }  RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") +APPEALABLE_INFRACTIONS = ("ban", "mute", "voice_ban")  # Type aliases  UserObject = t.Union[discord.Member, discord.User] @@ -154,7 +155,7 @@ async def notify_infraction(      log.trace(f"Sending {user} a DM about their {infr_type} infraction.")      text = INFRACTION_DESCRIPTION_TEMPLATE.format( -        type=infr_type.capitalize(), +        type=infr_type.title(),          expires=expires_at or "N/A",          reason=reason or "No reason provided."      ) diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index a8b3feb38..746d4e154 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -31,6 +31,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self.category = "Moderation"          self._muted_role = discord.Object(constants.Roles.muted) +        self._voice_verified_role = discord.Object(constants.Roles.voice_verified)      @commands.Cog.listener()      async def on_member_join(self, member: Member) -> None: @@ -71,6 +72,28 @@ class Infractions(InfractionScheduler, commands.Cog):          """Permanently ban a user for the given reason and stop watching them with Big Brother."""          await self.apply_ban(ctx, user, reason) +    @command(aliases=('pban',)) +    async def purgeban( +        self, +        ctx: Context, +        user: FetchedMember, +        purge_days: t.Optional[int] = 1, +        *, +        reason: t.Optional[str] = None +    ) -> None: +        """ +        Same as ban but removes all their messages for the given number of days, default being 1. + +        `purge_days` can only be values between 0 and 7. +        Anything outside these bounds are automatically adjusted to their respective limits. +        """ +        await self.apply_ban(ctx, user, reason, max(min(purge_days, 7), 0)) + +    @command(aliases=('vban',)) +    async def voiceban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str]) -> None: +        """Permanently ban user from using voice channels.""" +        await self.apply_voice_ban(ctx, user, reason) +      # endregion      # region: Temporary infractions @@ -119,6 +142,32 @@ class Infractions(InfractionScheduler, commands.Cog):          """          await self.apply_ban(ctx, user, reason, expires_at=duration) +    @command(aliases=("tempvban", "tvban")) +    async def tempvoiceban( +            self, +            ctx: Context, +            user: FetchedMember, +            duration: Expiry, +            *, +            reason: t.Optional[str] +    ) -> None: +        """ +        Temporarily voice ban a user for the given reason and duration. + +        A unit of time should be appended to the duration. +        Units (∗case-sensitive): +        \u2003`y` - years +        \u2003`m` - months∗ +        \u2003`w` - weeks +        \u2003`d` - days +        \u2003`h` - hours +        \u2003`M` - minutes∗ +        \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration. +        """ +        await self.apply_voice_ban(ctx, user, reason, expires_at=duration) +      # endregion      # region: Permanent shadow infractions @@ -208,6 +257,11 @@ class Infractions(InfractionScheduler, commands.Cog):          """Prematurely end the active ban infraction for the user."""          await self.pardon_infraction(ctx, "ban", user) +    @command(aliases=("uvban",)) +    async def unvoiceban(self, ctx: Context, user: FetchedMember) -> None: +        """Prematurely end the active voice ban infraction for the user.""" +        await self.pardon_infraction(ctx, "voice_ban", user) +      # endregion      # region: Base apply functions @@ -246,7 +300,14 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user, action)      @respect_role_hierarchy(member_arg=2) -    async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: +    async def apply_ban( +        self, +        ctx: Context, +        user: UserSnowflake, +        reason: t.Optional[str], +        purge_days: t.Optional[int] = 0, +        **kwargs +    ) -> None:          """          Apply a ban infraction with kwargs passed to `post_infraction`. @@ -278,7 +339,7 @@ class Infractions(InfractionScheduler, commands.Cog):          if reason:              reason = textwrap.shorten(reason, width=512, placeholder="...") -        action = ctx.guild.ban(user, reason=reason, delete_message_days=0) +        action = ctx.guild.ban(user, reason=reason, delete_message_days=purge_days)          await self.apply_infraction(ctx, infraction, user, action)          if infraction.get('expires_at') is not None: @@ -295,6 +356,26 @@ class Infractions(InfractionScheduler, commands.Cog):          bb_reason = "User has been permanently banned from the server. Automatically removed."          await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) +    @respect_role_hierarchy(member_arg=2) +    async def apply_voice_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: +        """Apply a voice ban infraction with kwargs passed to `post_infraction`.""" +        if await _utils.get_active_infraction(ctx, user, "voice_ban"): +            return + +        infraction = await _utils.post_infraction(ctx, user, "voice_ban", reason, active=True, **kwargs) +        if infraction is None: +            return + +        self.mod_log.ignore(Event.member_update, user.id) + +        if reason: +            reason = textwrap.shorten(reason, width=512, placeholder="...") + +        await user.move_to(None, reason="Disconnected from voice to apply voiceban.") + +        action = user.remove_roles(self._voice_verified_role, reason=reason) +        await self.apply_infraction(ctx, infraction, user, action) +      # endregion      # region: Base pardon functions @@ -339,6 +420,27 @@ class Infractions(InfractionScheduler, commands.Cog):          return log_text +    async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: +        """Add Voice Verified role back to user, DM them a notification, and return a log dict.""" +        user = guild.get_member(user_id) +        log_text = {} + +        if user: +            # DM user about infraction expiration +            notified = await _utils.notify_pardon( +                user=user, +                title="Voice ban ended", +                content="You have been unbanned and can verify yourself again in the server.", +                icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] +            ) + +            log_text["Member"] = format_user(user) +            log_text["DM"] = "Sent" if notified else "**Failed**" +        else: +            log_text["Info"] = "User was not found in the guild." + +        return log_text +      async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]:          """          Execute deactivation steps specific to the infraction's type and return a log dict. @@ -353,6 +455,8 @@ class Infractions(InfractionScheduler, commands.Cog):              return await self.pardon_mute(user_id, guild, reason)          elif infraction["type"] == "ban":              return await self.pardon_ban(user_id, guild, reason) +        elif infraction["type"] == "voice_ban": +            return await self.pardon_voice_ban(user_id, guild, reason)      # endregion diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index cdab1a6c7..394f63da3 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -15,7 +15,7 @@ from bot.exts.moderation.infraction.infractions import Infractions  from bot.exts.moderation.modlog import ModLog  from bot.pagination import LinePaginator  from bot.utils import messages, time -from bot.utils.checks import in_whitelist_check +from bot.utils.channel import is_mod_channel  log = logging.getLogger(__name__) @@ -295,13 +295,7 @@ class ModManagement(commands.Cog):          """Only allow moderators inside moderator channels to invoke the commands in this cog."""          checks = [              await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx), -            in_whitelist_check( -                ctx, -                channels=constants.MODERATION_CHANNELS, -                categories=[constants.Categories.modmail], -                redirect=None, -                fail_silently=True, -            ) +            is_mod_channel(ctx.channel)          ]          return all(checks) diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index eec63f5b3..adfe42fcd 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -135,7 +135,8 @@ class Superstarify(InfractionScheduler, Cog):              return          # Post the infraction to the API -        reason = reason or f"old nick: {member.display_name}" +        old_nick = member.display_name +        reason = reason or f"old nick: {old_nick}"          infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True)          id_ = infraction["id"] @@ -148,7 +149,7 @@ class Superstarify(InfractionScheduler, Cog):          await member.edit(nick=forced_nick, reason=reason)          self.schedule_expiration(infraction) -        old_nick = escape_markdown(member.display_name) +        old_nick = escape_markdown(old_nick)          forced_nick = escape_markdown(forced_nick)          # Send a DM to the user to notify them of their new infraction. diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index ac0c1c85e..e6712b3b6 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -1,8 +1,11 @@ -import asyncio +import json  import logging  from contextlib import suppress +from datetime import datetime, timedelta, timezone +from operator import attrgetter  from typing import Optional +from async_rediscache import RedisCache  from discord import TextChannel  from discord.ext import commands, tasks  from discord.ext.commands import Context @@ -10,10 +13,25 @@ from discord.ext.commands import Context  from bot.bot import Bot  from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles  from bot.converters import HushDurationConverter +from bot.utils.lock import LockedResourceError, lock_arg  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) +LOCK_NAMESPACE = "silence" + +MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced." +MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely." +MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)." + +MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced." +MSG_UNSILENCE_MANUAL = ( +    f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " +    f"set manually or the cache was prematurely cleared. " +    f"Please edit the overwrites manually to unsilence." +) +MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel." +  class SilenceNotifier(tasks.Loop):      """Loop notifier for posting notices to `alert_channel` containing added channels.""" @@ -56,25 +74,32 @@ class SilenceNotifier(tasks.Loop):  class Silence(commands.Cog):      """Commands for stopping channel messages for `verified` role in a channel.""" +    # Maps muted channel IDs to their previous overwrites for send_message and add_reactions. +    # Overwrites are stored as JSON. +    previous_overwrites = RedisCache() + +    # Maps muted channel IDs to POSIX timestamps of when they'll be unsilenced. +    # A timestamp equal to -1 means it's indefinite. +    unsilence_timestamps = RedisCache() +      def __init__(self, bot: Bot):          self.bot = bot          self.scheduler = Scheduler(self.__class__.__name__) -        self.muted_channels = set() -        self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) -        self._get_instance_vars_event = asyncio.Event() +        self._init_task = self.bot.loop.create_task(self._async_init()) -    async def _get_instance_vars(self) -> None: -        """Get instance variables after they're available to get from the guild.""" +    async def _async_init(self) -> None: +        """Set instance attributes once the guild is available and reschedule unsilences."""          await self.bot.wait_until_guild_available() +          guild = self.bot.get_guild(Guild.id)          self._verified_role = guild.get_role(Roles.verified)          self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) -        self._mod_log_channel = self.bot.get_channel(Channels.mod_log) -        self.notifier = SilenceNotifier(self._mod_log_channel) -        self._get_instance_vars_event.set() +        self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log)) +        await self._reschedule()      @commands.command(aliases=("hush",)) +    @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True)      async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None:          """          Silence the current channel for `duration` minutes or `forever`. @@ -82,18 +107,25 @@ class Silence(commands.Cog):          Duration is capped at 15 minutes, passing forever makes the silence indefinite.          Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start.          """ -        await self._get_instance_vars_event.wait() -        log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") -        if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): -            await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") -            return -        if duration is None: -            await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") +        await self._init_task + +        channel_info = f"#{ctx.channel} ({ctx.channel.id})" +        log.debug(f"{ctx.author} is silencing channel {channel_info}.") + +        if not await self._set_silence_overwrites(ctx.channel): +            log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.") +            await ctx.send(MSG_SILENCE_FAIL)              return -        await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") +        await self._schedule_unsilence(ctx, duration) -        self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) +        if duration is None: +            self.notifier.add_channel(ctx.channel) +            log.info(f"Silenced {channel_info} indefinitely.") +            await ctx.send(MSG_SILENCE_PERMANENT) +        else: +            log.info(f"Silenced {channel_info} for {duration} minute(s).") +            await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration))      @commands.command(aliases=("unhush",))      async def unsilence(self, ctx: Context) -> None: @@ -102,61 +134,115 @@ class Silence(commands.Cog):          If the channel was silenced indefinitely, notifications for the channel will stop.          """ -        await self._get_instance_vars_event.wait() +        await self._init_task          log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") -        if not await self._unsilence(ctx.channel): -            await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") +        await self._unsilence_wrapper(ctx.channel) + +    @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True) +    async def _unsilence_wrapper(self, channel: TextChannel) -> None: +        """Unsilence `channel` and send a success/failure message.""" +        if not await self._unsilence(channel): +            overwrite = channel.overwrites_for(self._verified_role) +            if overwrite.send_messages is False or overwrite.add_reactions is False: +                await channel.send(MSG_UNSILENCE_MANUAL) +            else: +                await channel.send(MSG_UNSILENCE_FAIL)          else: -            await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") +            await channel.send(MSG_UNSILENCE_SUCCESS) -    async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: -        """ -        Silence `channel` for `self._verified_role`. +    async def _set_silence_overwrites(self, channel: TextChannel) -> bool: +        """Set silence permission overwrites for `channel` and return True if successful.""" +        overwrite = channel.overwrites_for(self._verified_role) +        prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) -        If `persistent` is `True` add `channel` to notifier. -        `duration` is only used for logging; if None is passed `persistent` should be True to not log None. -        Return `True` if channel permissions were changed, `False` otherwise. -        """ -        current_overwrite = channel.overwrites_for(self._verified_role) -        if current_overwrite.send_messages is False: -            log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") +        if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()):              return False -        await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) -        self.muted_channels.add(channel) -        if persistent: -            log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") -            self.notifier.add_channel(channel) -            return True - -        log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") + +        overwrite.update(send_messages=False, add_reactions=False) +        await channel.set_permissions(self._verified_role, overwrite=overwrite) +        await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites)) +          return True +    async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None: +        """Schedule `ctx.channel` to be unsilenced if `duration` is not None.""" +        if duration is None: +            await self.unsilence_timestamps.set(ctx.channel.id, -1) +        else: +            self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) +            unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration) +            await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp()) +      async def _unsilence(self, channel: TextChannel) -> bool:          """          Unsilence `channel`. -        Check if `channel` is silenced through a `PermissionOverwrite`, -        if it is unsilence it and remove it from the notifier. +        If `channel` has a silence task scheduled or has its previous overwrites cached, unsilence +        it, cancel the task, and remove it from the notifier. Notify admins if it has a task but +        not cached overwrites. +          Return `True` if channel permissions were changed, `False` otherwise.          """ -        current_overwrite = channel.overwrites_for(self._verified_role) -        if current_overwrite.send_messages is False: -            await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) -            log.info(f"Unsilenced channel #{channel} ({channel.id}).") -            self.scheduler.cancel(channel.id) -            self.notifier.remove_channel(channel) -            self.muted_channels.discard(channel) -            return True -        log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") -        return False +        prev_overwrites = await self.previous_overwrites.get(channel.id) +        if channel.id not in self.scheduler and prev_overwrites is None: +            log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") +            return False + +        overwrite = channel.overwrites_for(self._verified_role) +        if prev_overwrites is None: +            log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.") +            overwrite.update(send_messages=None, add_reactions=None) +        else: +            overwrite.update(**json.loads(prev_overwrites)) + +        await channel.set_permissions(self._verified_role, overwrite=overwrite) +        log.info(f"Unsilenced channel #{channel} ({channel.id}).") + +        self.scheduler.cancel(channel.id) +        self.notifier.remove_channel(channel) +        await self.previous_overwrites.delete(channel.id) +        await self.unsilence_timestamps.delete(channel.id) + +        if prev_overwrites is None: +            await self._mod_alerts_channel.send( +                f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing " +                f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` " +                f"overwrites for {self._verified_role.mention} are at their desired values." +            ) + +        return True + +    async def _reschedule(self) -> None: +        """Reschedule unsilencing of active silences and add permanent ones to the notifier.""" +        for channel_id, timestamp in await self.unsilence_timestamps.items(): +            channel = self.bot.get_channel(channel_id) +            if channel is None: +                log.info(f"Can't reschedule silence for {channel_id}: channel not found.") +                continue + +            if timestamp == -1: +                log.info(f"Adding permanent silence for #{channel} ({channel.id}) to the notifier.") +                self.notifier.add_channel(channel) +                continue + +            dt = datetime.fromtimestamp(timestamp, tz=timezone.utc) +            delta = (dt - datetime.now(tz=timezone.utc)).total_seconds() +            if delta <= 0: +                # Suppress the error since it's not being invoked by a user via the command. +                with suppress(LockedResourceError): +                    await self._unsilence_wrapper(channel) +            else: +                log.info(f"Rescheduling silence for #{channel} ({channel.id}).") +                self.scheduler.schedule_later(delta, channel_id, self._unsilence_wrapper(channel))      def cog_unload(self) -> None: -        """Send alert with silenced channels and cancel scheduled tasks on unload.""" -        self.scheduler.cancel_all() -        if self.muted_channels: -            channels_string = ''.join(channel.mention for channel in self.muted_channels) -            message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" -            asyncio.create_task(self._mod_alerts_channel.send(message)) +        """Cancel the init task and scheduled tasks.""" +        # It's important to wait for _init_task (specifically for _reschedule) to be cancelled +        # before cancelling scheduled tasks. Otherwise, it's possible for _reschedule to schedule +        # more tasks after cancel_all has finished, despite _init_task.cancel being called first. +        # This is cause cancel() on its own doesn't block until the task is cancelled. +        self._init_task.cancel() +        self._init_task.add_done_callback(lambda _: self.scheduler.cancel_all())      # This cannot be static (must have a __func__ attribute).      async def cog_check(self, ctx: Context) -> bool: diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index 206556483..c599156d0 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -11,6 +11,7 @@ from discord.ext.commands import Cog, Context, command, group, has_any_role  from discord.utils import snowflake_time  from bot import constants +from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.decorators import has_no_roles, in_whitelist  from bot.exts.moderation.modlog import ModLog @@ -53,6 +54,23 @@ If you'd like to unsubscribe from the announcement notifications, simply send `!  <#{constants.Channels.bot_commands}>.  """ +ALTERNATE_VERIFIED_MESSAGE = f""" +Thanks for accepting our rules! + +You can find a copy of our rules for reference at <https://pythondiscord.com/pages/rules>. + +Additionally, if you'd like to receive notifications for the announcements \ +we post in <#{constants.Channels.announcements}> +from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \ +to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement. + +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{constants.Channels.bot_commands}>. + +To introduce you to our community, we've made the following video: +https://youtu.be/ZH26PuX3re0 +""" +  # Sent via DMs to users kicked for failing to verify  KICKED_MESSAGE = f"""  Hi! You have been automatically kicked from Python Discord as you have failed to accept our rules \ @@ -156,6 +174,9 @@ class Verification(Cog):      # ]      task_cache = RedisCache() +    # Create a cache for storing recipients of the alternate welcome DM. +    member_gating_cache = RedisCache() +      def __init__(self, bot: Bot) -> None:          """Start internal tasks."""          self.bot = bot @@ -335,6 +356,28 @@ class Verification(Cog):          return n_success +    async def _add_kick_note(self, member: discord.Member) -> None: +        """ +        Post a note regarding `member` being kicked to site. + +        Allows keeping track of kicked members for auditing purposes. +        """ +        payload = { +            "active": False, +            "actor": self.bot.user.id,  # Bot actions this autonomously +            "expires_at": None, +            "hidden": True, +            "reason": "Verification kick", +            "type": "note", +            "user": member.id, +        } + +        log.trace(f"Posting kick note for member {member} ({member.id})") +        try: +            await self.bot.api_client.post("bot/infractions", json=payload) +        except ResponseCodeError as api_exc: +            log.warning("Failed to post kick note", exc_info=api_exc) +      async def _kick_members(self, members: t.Collection[discord.Member]) -> int:          """          Kick `members` from the PyDis guild. @@ -353,6 +396,7 @@ class Verification(Cog):              except discord.HTTPException as suspicious_exception:                  raise StopExecution(reason=suspicious_exception)              await member.kick(reason=f"User has not verified in {constants.Verification.kicked_after} days") +            await self._add_kick_note(member)          n_kicked = await self._send_requests(members, kick_request, Limit(batch_size=2, sleep_secs=1))          self.bot.stats.incr("verification.kicked", count=n_kicked) @@ -519,6 +563,26 @@ class Verification(Cog):          if member.guild.id != constants.Guild.id:              return  # Only listen for PyDis events +        raw_member = await self.bot.http.get_member(member.guild.id, member.id) + +        # If the user has the is_pending flag set, they will be using the alternate +        # gate and will not need a welcome DM with verification instructions. +        # We will send them an alternate DM once they verify with the welcome +        # video. +        if raw_member.get("is_pending"): +            await self.member_gating_cache.set(member.id, True) + +            # TODO: Temporary, remove soon after asking joe. +            await self.mod_log.send_log_message( +                icon_url=self.bot.user.avatar_url, +                colour=discord.Colour.blurple(), +                title="New native gated user", +                channel_id=constants.Channels.user_log, +                text=f"<@{member.id}> ({member.id})", +            ) + +            return +          log.trace(f"Sending on join message to new member: {member.id}")          try:              await safe_dm(member.send(ON_JOIN_MESSAGE)) @@ -526,6 +590,23 @@ class Verification(Cog):              log.exception("DM dispatch failed on unexpected error code")      @Cog.listener() +    async def on_member_update(self, before: discord.Member, after: discord.Member) -> None: +        """Check if we need to send a verification DM to a gated user.""" +        before_roles = [role.id for role in before.roles] +        after_roles = [role.id for role in after.roles] + +        if constants.Roles.verified not in before_roles and constants.Roles.verified in after_roles: +            if await self.member_gating_cache.pop(after.id): +                try: +                    # If the member has not received a DM from our !accept command +                    # and has gone through the alternate gating system we should send +                    # our alternate welcome DM which includes info such as our welcome +                    # video. +                    await safe_dm(after.send(ALTERNATE_VERIFIED_MESSAGE)) +                except discord.HTTPException: +                    log.exception("DM dispatch failed on unexpected error code") + +    @Cog.listener()      async def on_message(self, message: discord.Message) -> None:          """Check new message event for messages to the checkpoint channel & process."""          if message.channel.id != constants.Channels.verification: diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py new file mode 100644 index 000000000..c2743e136 --- /dev/null +++ b/bot/exts/moderation/voice_gate.py @@ -0,0 +1,168 @@ +import asyncio +import logging +from contextlib import suppress +from datetime import datetime, timedelta + +import discord +from dateutil import parser +from discord import Colour +from discord.ext.commands import Cog, Context, command + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Event, MODERATION_ROLES, Roles, VoiceGate as GateConf +from bot.decorators import has_no_roles, in_whitelist +from bot.exts.moderation.modlog import ModLog +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + +FAILED_MESSAGE = ( +    """You are not currently eligible to use voice inside Python Discord for the following reasons:\n\n{reasons}""" +) + +MESSAGE_FIELD_MAP = { +    "verified_at": f"have been verified for less than {GateConf.minimum_days_verified} days", +    "voice_banned": "have an active voice ban infraction", +    "total_messages": f"have sent less than {GateConf.minimum_messages} messages", +} + + +class VoiceGate(Cog): +    """Voice channels verification management.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @property +    def mod_log(self) -> ModLog: +        """Get the currently loaded ModLog cog instance.""" +        return self.bot.get_cog("ModLog") + +    @command(aliases=('voiceverify',)) +    @has_no_roles(Roles.voice_verified) +    @in_whitelist(channels=(Channels.voice_gate,), redirect=None) +    async def voice_verify(self, ctx: Context, *_) -> None: +        """ +        Apply to be able to use voice within the Discord server. + +        In order to use voice you must meet all three of the following criteria: +        - You must have over a certain number of messages within the Discord server +        - You must have accepted our rules over a certain number of days ago +        - You must not be actively banned from using our voice channels +        """ +        try: +            data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data") +        except ResponseCodeError as e: +            if e.status == 404: +                embed = discord.Embed( +                    title="Not found", +                    description=( +                        "We were unable to find user data for you. " +                        "Please try again shortly, " +                        "if this problem persists please contact the server staff through Modmail.", +                    ), +                    color=Colour.red() +                ) +                log.info(f"Unable to find Metricity data about {ctx.author} ({ctx.author.id})") +            else: +                embed = discord.Embed( +                    title="Unexpected response", +                    description=( +                        "We encountered an error while attempting to find data for your user. " +                        "Please try again and let us know if the problem persists." +                    ), +                    color=Colour.red() +                ) +                log.warning(f"Got response code {e.status} while trying to get {ctx.author.id} Metricity data.") + +            await ctx.author.send(embed=embed) +            return + +        # Pre-parse this for better code style +        if data["verified_at"] is not None: +            data["verified_at"] = parser.isoparse(data["verified_at"]) +        else: +            data["verified_at"] = datetime.utcnow() - timedelta(days=3) + +        checks = { +            "verified_at": data["verified_at"] > datetime.utcnow() - timedelta(days=GateConf.minimum_days_verified), +            "total_messages": data["total_messages"] < GateConf.minimum_messages, +            "voice_banned": data["voice_banned"] +        } +        failed = any(checks.values()) +        failed_reasons = [MESSAGE_FIELD_MAP[key] for key, value in checks.items() if value is True] +        [self.bot.stats.incr(f"voice_gate.failed.{key}") for key, value in checks.items() if value is True] + +        if failed: +            embed = discord.Embed( +                title="Voice Gate failed", +                description=FAILED_MESSAGE.format(reasons="\n".join(f'• You {reason}.' for reason in failed_reasons)), +                color=Colour.red() +            ) +            try: +                await ctx.author.send(embed=embed) +                await ctx.send(f"{ctx.author}, please check your DMs.") +            except discord.Forbidden: +                await ctx.channel.send(ctx.author.mention, embed=embed) +            return + +        self.mod_log.ignore(Event.member_update, ctx.author.id) +        embed = discord.Embed( +            title="Voice gate passed", +            description="You have been granted permission to use voice channels in Python Discord.", +            color=Colour.green() +        ) + +        if ctx.author.voice: +            embed.description += "\n\nPlease reconnect to your voice channel to be granted your new permissions." + +        try: +            await ctx.author.send(embed=embed) +            await ctx.send(f"{ctx.author}, please check your DMs.") +        except discord.Forbidden: +            await ctx.channel.send(ctx.author.mention, embed=embed) + +        # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it. +        await asyncio.sleep(3) +        await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed") + +        self.bot.stats.incr("voice_gate.passed") + +    @Cog.listener() +    async def on_message(self, message: discord.Message) -> None: +        """Delete all non-staff messages from voice gate channel that don't invoke voice verify command.""" +        # Check is channel voice gate +        if message.channel.id != Channels.voice_gate: +            return + +        ctx = await self.bot.get_context(message) +        is_verify_command = ctx.command is not None and ctx.command.name == "voice_verify" + +        # When it's bot sent message, delete it after some time +        if message.author.bot: +            with suppress(discord.NotFound): +                await message.delete(delay=GateConf.bot_message_delete_delay) +                return + +        # Then check is member moderator+, because we don't want to delete their messages. +        if any(role.id in MODERATION_ROLES for role in message.author.roles) and is_verify_command is False: +            log.trace(f"Excluding moderator message {message.id} from deletion in #{message.channel}.") +            return + +        # Ignore deleted voice verification messages +        if ctx.command is not None and ctx.command.name == "voice_verify": +            self.mod_log.ignore(Event.message_delete, message.id) + +        with suppress(discord.NotFound): +            await message.delete() + +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Check for & ignore any InWhitelistCheckFailure.""" +        if isinstance(error, InWhitelistCheckFailure): +            error.handled = True + + +def setup(bot: Bot) -> None: +    """Loads the VoiceGate cog.""" +    bot.add_cog(VoiceGate(bot)) diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py index ba1fd2a5c..69d623581 100644 --- a/bot/exts/utils/bot.py +++ b/bot/exts/utils/bot.py @@ -1,22 +1,14 @@ -import ast  import logging -import re -import time -from typing import Optional, Tuple +from typing import Optional -from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord import Embed, TextChannel  from discord.ext.commands import Cog, Context, command, group, has_any_role  from bot.bot import Bot -from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.exts.filters.token_remover import TokenRemover -from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE -from bot.utils.messages import wait_for_deletion +from bot.constants import Guild, MODERATION_ROLES, Roles, URLs  log = logging.getLogger(__name__) -RE_MARKDOWN = re.compile(r'([*_~`|>])') -  class BotCog(Cog, name="Bot"):      """Bot information commands.""" @@ -24,19 +16,6 @@ class BotCog(Cog, name="Bot"):      def __init__(self, bot: Bot):          self.bot = bot -        # Stores allowed channels plus epoch time since last call. -        self.channel_cooldowns = { -            Channels.python_discussion: 0, -        } - -        # These channels will also work, but will not be subject to cooldown -        self.channel_whitelist = ( -            Channels.bot_commands, -        ) - -        # Stores improperly formatted Python codeblock message ids and the corresponding bot message -        self.codeblock_message_ids = {} -      @group(invoke_without_command=True, name="bot", hidden=True)      @has_any_role(Roles.verified)      async def botinfo_group(self, ctx: Context) -> None: @@ -81,305 +60,6 @@ class BotCog(Cog, name="Bot"):          else:              await channel.send(embed=embed) -    def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: -        """ -        Strip msg in order to find Python code. - -        Tries to strip out Python code out of msg and returns the stripped block or -        None if the block is a valid Python codeblock. -        """ -        if msg.count("\n") >= 3: -            # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. -            if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: -                log.trace( -                    "Someone wrote a message that was already a " -                    "valid Python syntax highlighted code block. No action taken." -                ) -                return None - -            else: -                # Stripping backticks from every line of the message. -                log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") -                content = "" -                for line in msg.splitlines(keepends=True): -                    content += line.strip("`") - -                content = content.strip() - -                # Remove "Python" or "Py" from start of the message if it exists. -                log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") -                pycode = False -                if content.lower().startswith("python"): -                    content = content[6:] -                    pycode = True -                elif content.lower().startswith("py"): -                    content = content[2:] -                    pycode = True - -                if pycode: -                    content = content.splitlines(keepends=True) - -                    # Check if there might be code in the first line, and preserve it. -                    first_line = content[0] -                    if " " in content[0]: -                        first_space = first_line.index(" ") -                        content[0] = first_line[first_space:] -                        content = "".join(content) - -                    # If there's no code we can just get rid of the first line. -                    else: -                        content = "".join(content[1:]) - -                # Strip it again to remove any leading whitespace. This is necessary -                # if the first line of the message looked like ```python <code> -                old = content.strip() - -                # Strips REPL code out of the message if there is any. -                content, repl_code = self.repl_stripping(old) -                if old != content: -                    return (content, old), repl_code - -                # Try to apply indentation fixes to the code. -                content = self.fix_indentation(content) - -                # Check if the code contains backticks, if it does ignore the message. -                if "`" in content: -                    log.trace("Detected ` inside the code, won't reply") -                    return None -                else: -                    log.trace(f"Returning message.\n\n{content}\n\n") -                    return (content,), repl_code - -    def fix_indentation(self, msg: str) -> str: -        """Attempts to fix badly indented code.""" -        def unindent(code: str, skip_spaces: int = 0) -> str: -            """Unindents all code down to the number of spaces given in skip_spaces.""" -            final = "" -            current = code[0] -            leading_spaces = 0 - -            # Get numbers of spaces before code in the first line. -            while current == " ": -                current = code[leading_spaces + 1] -                leading_spaces += 1 -            leading_spaces -= skip_spaces - -            # If there are any, remove that number of spaces from every line. -            if leading_spaces > 0: -                for line in code.splitlines(keepends=True): -                    line = line[leading_spaces:] -                    final += line -                return final -            else: -                return code - -        # Apply fix for "all lines are overindented" case. -        msg = unindent(msg) - -        # If the first line does not end with a colon, we can be -        # certain the next line will be on the same indentation level. -        # -        # If it does end with a colon, we will need to indent all successive -        # lines one additional level. -        first_line = msg.splitlines()[0] -        code = "".join(msg.splitlines(keepends=True)[1:]) -        if not first_line.endswith(":"): -            msg = f"{first_line}\n{unindent(code)}" -        else: -            msg = f"{first_line}\n{unindent(code, 4)}" -        return msg - -    def repl_stripping(self, msg: str) -> Tuple[str, bool]: -        """ -        Strip msg in order to extract Python code out of REPL output. - -        Tries to strip out REPL Python code out of msg and returns the stripped msg. - -        Returns True for the boolean if REPL code was found in the input msg. -        """ -        final = "" -        for line in msg.splitlines(keepends=True): -            if line.startswith(">>>") or line.startswith("..."): -                final += line[4:] -        log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") -        if not final: -            log.trace(f"Found no REPL code in \n\n{msg}\n\n") -            return msg, False -        else: -            log.trace(f"Found REPL code in \n\n{msg}\n\n") -            return final.rstrip(), True - -    def has_bad_ticks(self, msg: Message) -> bool: -        """Check to see if msg contains ticks that aren't '`'.""" -        not_backticks = [ -            "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", -            "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", -            "\u3003\u3003\u3003" -        ] - -        return msg.content[:3] in not_backticks - -    @Cog.listener() -    async def on_message(self, msg: Message) -> None: -        """ -        Detect poorly formatted Python code in new messages. - -        If poorly formatted code is detected, send the user a helpful message explaining how to do -        properly formatted Python syntax highlighting codeblocks. -        """ -        is_help_channel = ( -            getattr(msg.channel, "category", None) -            and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) -        ) -        parse_codeblock = ( -            ( -                is_help_channel -                or msg.channel.id in self.channel_cooldowns -                or msg.channel.id in self.channel_whitelist -            ) -            and not msg.author.bot -            and len(msg.content.splitlines()) > 3 -            and not TokenRemover.find_token_in_message(msg) -            and not WEBHOOK_URL_RE.search(msg.content) -        ) - -        if parse_codeblock:  # no token in the msg -            on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 -            if not on_cooldown or DEBUG_MODE: -                try: -                    if self.has_bad_ticks(msg): -                        ticks = msg.content[:3] -                        content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) -                        if content is None: -                            return - -                        content, repl_code = content - -                        if len(content) == 2: -                            content = content[1] -                        else: -                            content = content[0] - -                        space_left = 204 -                        if len(content) >= space_left: -                            current_length = 0 -                            lines_walked = 0 -                            for line in content.splitlines(keepends=True): -                                if current_length + len(line) > space_left or lines_walked == 10: -                                    break -                                current_length += len(line) -                                lines_walked += 1 -                            content = content[:current_length] + "#..." -                        content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) -                        howto = ( -                            "It looks like you are trying to paste code into this channel.\n\n" -                            "You seem to be using the wrong symbols to indicate where the codeblock should start. " -                            f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" -                            "**Here is an example of how it should look:**\n" -                            f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" -                            "**This will result in the following:**\n" -                            f"```python\n{content}\n```" -                        ) - -                    else: -                        howto = "" -                        content = self.codeblock_stripping(msg.content, False) -                        if content is None: -                            return - -                        content, repl_code = content -                        # Attempts to parse the message into an AST node. -                        # Invalid Python code will raise a SyntaxError. -                        tree = ast.parse(content[0]) - -                        # Multiple lines of single words could be interpreted as expressions. -                        # This check is to avoid all nodes being parsed as expressions. -                        # (e.g. words over multiple lines) -                        if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: -                            # Shorten the code to 10 lines and/or 204 characters. -                            space_left = 204 -                            if content and repl_code: -                                content = content[1] -                            else: -                                content = content[0] - -                            if len(content) >= space_left: -                                current_length = 0 -                                lines_walked = 0 -                                for line in content.splitlines(keepends=True): -                                    if current_length + len(line) > space_left or lines_walked == 10: -                                        break -                                    current_length += len(line) -                                    lines_walked += 1 -                                content = content[:current_length] + "#..." - -                            content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) -                            howto += ( -                                "It looks like you're trying to paste code into this channel.\n\n" -                                "Discord has support for Markdown, which allows you to post code with full " -                                "syntax highlighting. Please use these whenever you paste code, as this " -                                "helps improve the legibility and makes it easier for us to help you.\n\n" -                                f"**To do this, use the following method:**\n" -                                f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" -                                "**This will result in the following:**\n" -                                f"```python\n{content}\n```" -                            ) - -                            log.debug(f"{msg.author} posted something that needed to be put inside python code " -                                      "blocks. Sending the user some instructions.") -                        else: -                            log.trace("The code consists only of expressions, not sending instructions") - -                    if howto != "": -                        # Increase amount of codeblock correction in stats -                        self.bot.stats.incr("codeblock_corrections") -                        howto_embed = Embed(description=howto) -                        bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) -                        self.codeblock_message_ids[msg.id] = bot_message.id - -                        self.bot.loop.create_task( -                            wait_for_deletion(bot_message, (msg.author.id,), self.bot) -                        ) -                    else: -                        return - -                    if msg.channel.id not in self.channel_whitelist: -                        self.channel_cooldowns[msg.channel.id] = time.time() - -                except SyntaxError: -                    log.trace( -                        f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " -                        "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " -                        f"The message that was posted was:\n\n{msg.content}\n\n" -                    ) - -    @Cog.listener() -    async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: -        """Check to see if an edited message (previously called out) still contains poorly formatted code.""" -        if ( -            # Checks to see if the message was called out by the bot -            payload.message_id not in self.codeblock_message_ids -            # Makes sure that there is content in the message -            or payload.data.get("content") is None -            # Makes sure there's a channel id in the message payload -            or payload.data.get("channel_id") is None -        ): -            return - -        # Retrieve channel and message objects for use later -        channel = self.bot.get_channel(int(payload.data.get("channel_id"))) -        user_message = await channel.fetch_message(payload.message_id) - -        #  Checks to see if the user has corrected their codeblock.  If it's fixed, has_fixed_codeblock will be None -        has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) - -        # If the message is fixed, delete the bot message and the entry from the id dictionary -        if has_fixed_codeblock is None: -            bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) -            await bot_message.delete() -            del self.codeblock_message_ids[payload.message_id] -            log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") -  def setup(bot: Bot) -> None:      """Load the Bot cog.""" diff --git a/bot/exts/utils/ping.py b/bot/exts/utils/ping.py index a9ca3dbeb..572fc934b 100644 --- a/bot/exts/utils/ping.py +++ b/bot/exts/utils/ping.py @@ -33,7 +33,7 @@ class Latency(commands.Cog):          """          # datetime.datetime objects do not have the "milliseconds" attribute.          # It must be converted to seconds before converting to milliseconds. -        bot_ping = (datetime.utcnow() - ctx.message.created_at).total_seconds() / 1000 +        bot_ping = (datetime.utcnow() - ctx.message.created_at).total_seconds() * 1000          bot_ping = f"{bot_ping:.{ROUND_LATENCY}f} ms"          try: diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index bf4e24661..3113a1149 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -23,7 +23,7 @@ from bot.utils.time import humanize_delta  log = logging.getLogger(__name__) -NAMESPACE = "reminder"  # Used for the mutually_exclusive decorator; constant to prevent typos +LOCK_NAMESPACE = "reminder"  WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5 @@ -170,7 +170,7 @@ class Reminders(Cog):          log.trace(f"Scheduling new task #{reminder['id']}")          self.schedule_reminder(reminder) -    @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True) +    @lock_arg(LOCK_NAMESPACE, "reminder", itemgetter("id"), raise_error=True)      async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:          """Send the reminder."""          is_valid, user, channel = self.ensure_valid_reminder(reminder) @@ -378,7 +378,7 @@ class Reminders(Cog):          mention_ids = [mention.id for mention in mentions]          await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) -    @lock_arg(NAMESPACE, "id_", raise_error=True) +    @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)      async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None:          """Edits a reminder with the given payload, then sends a confirmation message."""          if not await self._can_modify(ctx, id_): @@ -398,7 +398,7 @@ class Reminders(Cog):          await self._reschedule_reminder(reminder)      @remind_group.command("delete", aliases=("remove", "cancel")) -    @lock_arg(NAMESPACE, "id_", raise_error=True) +    @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)      async def delete_reminder(self, ctx: Context, id_: int) -> None:          """Delete one of your active reminders."""          if not await self._can_modify(ctx, id_): diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index da3e07f42..41cb00541 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -36,11 +36,11 @@ RAW_CODE_REGEX = re.compile(      re.DOTALL                               # "." also matches newlines  ) -MAX_PASTE_LEN = 1000 +MAX_PASTE_LEN = 10000  # `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric, Channels.code_help_voice) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) +EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use, Categories.voice)  EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)  SIGKILL = 9 diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 60170a88f..13533a467 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,4 +1,4 @@ -from bot.utils.helpers import CogABCMeta, find_nth_occurrence, pad_base64 +from bot.utils.helpers import CogABCMeta, find_nth_occurrence, has_lines, pad_base64  from bot.utils.services import send_to_paste_service -__all__ = ['CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service'] +__all__ = ['CogABCMeta', 'find_nth_occurrence', 'has_lines', 'pad_base64', 'send_to_paste_service'] diff --git a/bot/utils/channel.py b/bot/utils/channel.py new file mode 100644 index 000000000..6bf70bfde --- /dev/null +++ b/bot/utils/channel.py @@ -0,0 +1,49 @@ +import logging + +import discord + +from bot import constants +from bot.constants import Categories + +log = logging.getLogger(__name__) + + +def is_help_channel(channel: discord.TextChannel) -> bool: +    """Return True if `channel` is in one of the help categories (excluding dormant).""" +    log.trace(f"Checking if #{channel} is a help channel.") +    categories = (Categories.help_available, Categories.help_in_use) + +    return any(is_in_category(channel, category) for category in categories) + + +def is_mod_channel(channel: discord.TextChannel) -> bool: +    """True if `channel` is considered a mod channel.""" +    if channel.id in constants.MODERATION_CHANNELS: +        log.trace(f"Channel #{channel} is a configured mod channel") +        return True + +    elif any(is_in_category(channel, category) for category in constants.MODERATION_CATEGORIES): +        log.trace(f"Channel #{channel} is in a configured mod category") +        return True + +    else: +        log.trace(f"Channel #{channel} is not a mod channel") +        return False + + +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: +    """Return True if `channel` is within a category with `category_id`.""" +    return getattr(channel, "category_id", None) == category_id + + +async def try_get_channel(channel_id: int, client: discord.Client) -> discord.abc.GuildChannel: +    """Attempt to get or fetch a channel and return it.""" +    log.trace(f"Getting the channel {channel_id}.") + +    channel = client.get_channel(channel_id) +    if not channel: +        log.debug(f"Channel {channel_id} is not in cache; fetching from API.") +        channel = await client.fetch_channel(channel_id) + +    log.trace(f"Channel #{channel} ({channel_id}) retrieved.") +    return channel diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py index d9b60af07..3501a3933 100644 --- a/bot/utils/helpers.py +++ b/bot/utils/helpers.py @@ -18,6 +18,15 @@ def find_nth_occurrence(string: str, substring: str, n: int) -> Optional[int]:      return index +def has_lines(string: str, count: int) -> bool: +    """Return True if `string` has at least `count` lines.""" +    # Benchmarks show this is significantly faster than using str.count("\n") or a for loop & break. +    split = string.split("\n", count - 1) + +    # Make sure the last part isn't empty, which would happen if there was a final newline. +    return split[-1] and len(split) == count + +  def pad_base64(data: str) -> str:      """Return base64 `data` with padding characters to ensure its length is a multiple of 4."""      return data + "=" * (-len(data) % 4) diff --git a/bot/utils/messages.py b/bot/utils/messages.py index d0b2342b3..b6c7cab50 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -56,15 +56,24 @@ async def wait_for_deletion(  async def send_attachments(      message: discord.Message,      destination: Union[discord.TextChannel, discord.Webhook], -    link_large: bool = True +    link_large: bool = True, +    use_cached: bool = False, +    **kwargs  ) -> List[str]:      """      Re-upload the message's attachments to the destination and return a list of their new URLs.      Each attachment is sent as a separate message to more easily comply with the request/file size      limit. If link_large is True, attachments which are too large are instead grouped into a single -    embed which links to them. +    embed which links to them. Extra kwargs will be passed to send() when sending the attachment.      """ +    webhook_send_kwargs = { +        'username': message.author.display_name, +        'avatar_url': message.author.avatar_url, +    } +    webhook_send_kwargs.update(kwargs) +    webhook_send_kwargs['username'] = sub_clyde(webhook_send_kwargs['username']) +      large = []      urls = []      for attachment in message.attachments: @@ -78,18 +87,14 @@ async def send_attachments(              # but some may get through hence the try-catch.              if attachment.size <= destination.guild.filesize_limit - 512:                  with BytesIO() as file: -                    await attachment.save(file, use_cached=True) +                    await attachment.save(file, use_cached=use_cached)                      attachment_file = discord.File(file, filename=attachment.filename)                      if isinstance(destination, discord.TextChannel): -                        msg = await destination.send(file=attachment_file) +                        msg = await destination.send(file=attachment_file, **kwargs)                          urls.append(msg.attachments[0].url)                      else: -                        await destination.send( -                            file=attachment_file, -                            username=sub_clyde(message.author.display_name), -                            avatar_url=message.author.avatar_url -                        ) +                        await destination.send(file=attachment_file, **webhook_send_kwargs)              elif link_large:                  large.append(attachment)              else: @@ -106,13 +111,9 @@ async def send_attachments(          embed.set_footer(text="Attachments exceed upload size limit.")          if isinstance(destination, discord.TextChannel): -            await destination.send(embed=embed) +            await destination.send(embed=embed, **kwargs)          else: -            await destination.send( -                embed=embed, -                username=sub_clyde(message.author.display_name), -                avatar_url=message.author.avatar_url -            ) +            await destination.send(embed=embed, **webhook_send_kwargs)      return urls diff --git a/config-default.yml b/config-default.yml index 4f7b1e217..071f6e1ec 100644 --- a/config-default.yml +++ b/config-default.yml @@ -119,6 +119,7 @@ style:          voice_state_green: "https://cdn.discordapp.com/emojis/656899770094452754.png"          voice_state_red: "https://cdn.discordapp.com/emojis/656899769905709076.png" +  guild:      id: 267624335836053506      invite: "https://discord.gg/python" @@ -127,7 +128,9 @@ guild:          help_available:                     691405807388196926          help_in_use:                        696958401460043776          help_dormant:                       691405908919451718 -        modmail:                            714494672835444826 +        modmail:            &MODMAIL        714494672835444826 +        logs:               &LOGS           468520609152892958 +        voice:                              356013253765234688      channels:          # Public announcement and news channels @@ -145,8 +148,8 @@ guild:          dev_log:            &DEV_LOG        622895325144940554          # Discussion -        meta:               429409067623251969 -        python_discussion:  267624335836053506 +        meta:                               429409067623251969 +        python_discussion:  &PY_DISCUSSION  267624335836053506          # Python Help: Available          how_to_get_help:    704250143020417084 @@ -169,6 +172,7 @@ guild:          bot_commands:       &BOT_CMD        267659945086812160          esoteric:                           470884583684964352          verification:                       352442727016693763 +        voice_gate:                         764802555427029012          # Staff          admins:             &ADMINS         365960823622991872 @@ -178,7 +182,7 @@ guild:          incidents:                          714214212200562749          incidents_archive:                  720668923636351037          mods:               &MODS           305126844661760000 -        mod_alerts:         &MOD_ALERTS     473092532147060736 +        mod_alerts:                         473092532147060736          mod_spam:           &MOD_SPAM       620607373828030464          organisation:       &ORGANISATION   551789653284356126          staff_lounge:       &STAFF_LOUNGE   464905259261755392 @@ -191,6 +195,8 @@ guild:          # Voice          code_help_voice:                    755154969761677312 +        code_help_voice_2:                  766330079135268884 +        voice_chat:                         412357430186344448          admins_voice:       &ADMINS_VOICE   500734494840717332          staff_voice:        &STAFF_VOICE    412375055910043655 @@ -198,10 +204,13 @@ guild:          big_brother_logs:   &BB_LOGS        468507907357409333          talent_pool:        &TALENT_POOL    534321732593647616 +    moderation_categories: +        - *MODMAIL +        - *LOGS +      moderation_channels:          - *ADMINS          - *ADMIN_SPAM -        - *MOD_ALERTS          - *MODS          - *MOD_SPAM @@ -225,9 +234,11 @@ guild:          muted:              &MUTED_ROLE         277914926603829249          partners:                               323426753857191936          python_community:   &PY_COMMUNITY_ROLE  458226413825294336 +        sprinters:          &SPRINTERS          758422482289426471          unverified:                             739794855945044069          verified:                               352427296948486144  # @Developers on PyDis +        voice_verified:                         764802720779337729          # Staff          admins:             &ADMINS_ROLE    267628507062992896 @@ -261,6 +272,7 @@ guild:          reddit:                             635408384794951680          talent_pool:                        569145364800602132 +  filter:      # What do we filter?      filter_zalgo:          false @@ -298,6 +310,7 @@ filter:          - *OWNERS_ROLE          - *HELPERS_ROLE          - *PY_COMMUNITY_ROLE +        - *SPRINTERS  keys: @@ -326,6 +339,7 @@ urls:      bot_avatar:      "https://raw.githubusercontent.com/discord-python/branding/master/logos/logo_circle/logo_circle.png"      github_bot_repo: "https://github.com/python-discord/bot" +  anti_spam:      # Clean messages that violate a rule.      clean_offending: true @@ -394,6 +408,23 @@ big_brother:      header_message_limit: 15 +code_block: +    # The channels in which code blocks will be detected. They are not subject to a cooldown. +    channel_whitelist: +        - *BOT_CMD + +    # The channels which will be affected by a cooldown. These channels are also whitelisted. +    cooldown_channels: +        - *PY_DISCUSSION + +    # Sending instructions triggers a cooldown on a per-channel basis. +    # More instruction messages will not be sent in the same channel until the cooldown has elapsed. +    cooldown_seconds: 300 + +    # The minimum amount of lines a message or code block must have for instructions to be sent. +    minimum_lines: 4 + +  free:      # Seconds to elapse for a channel      # to be considered inactive. @@ -442,10 +473,12 @@ help_channels:      notify_roles:          - *HELPERS_ROLE +  redirect_output:      delete_invocation: true      delete_delay: 15 +  duck_pond:      threshold: 4      channel_blacklist: @@ -461,11 +494,13 @@ duck_pond:          - *MOD_ANNOUNCEMENTS          - *ADMIN_ANNOUNCEMENTS +  python_news:      mail_lists:          - 'python-ideas'          - 'python-announce-list'          - 'pypi-announce' +        - 'python-dev'      channel: *PYNEWS_CHANNEL      webhook: *PYNEWS_WEBHOOK @@ -482,5 +517,11 @@ verification:      kick_confirmation_threshold: 0.01  # 1% +voice_gate: +    minimum_days_verified: 3  # How many days the user must have been verified for +    minimum_messages: 50  # How many messages a user must have to be eligible for voice +    bot_message_delete_delay: 10  # Seconds before deleting bot's response in Voice Gate + +  config:      required_keys: ['bot.token'] diff --git a/docker-compose.yml b/docker-compose.yml index cff7d33d6..8be5aac0e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -41,6 +41,7 @@ services:        - postgres      environment:        DATABASE_URL: postgres://pysite:pysite@postgres:5432/pysite +      METRICITY_DB_URL: postgres://pysite:pysite@postgres:5432/metricity        SECRET_KEY: suitable-for-development-only        STATIC_ROOT: /var/www/static diff --git a/tests/_autospec.py b/tests/_autospec.py new file mode 100644 index 000000000..ee2fc1973 --- /dev/null +++ b/tests/_autospec.py @@ -0,0 +1,64 @@ +import contextlib +import functools +import unittest.mock +from typing import Callable + + [email protected](unittest.mock._patch.decoration_helper) +def _decoration_helper(self, patched, args, keywargs): +    """Skips adding patchings as args if their `dont_pass` attribute is True.""" +    # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added. +    extra_args = [] +    with contextlib.ExitStack() as exit_stack: +        for patching in patched.patchings: +            arg = exit_stack.enter_context(patching) +            if not getattr(patching, "dont_pass", False): +                # Only add the patching as an arg if dont_pass is False. +                if patching.attribute_name is not None: +                    keywargs.update(arg) +                elif patching.new is unittest.mock.DEFAULT: +                    extra_args.append(arg) + +        args += tuple(extra_args) +        yield args, keywargs + + [email protected](unittest.mock._patch.copy) +def _copy(self): +    """Copy the `dont_pass` attribute along with the standard copy operation.""" +    patcher_copy = _copy.original(self) +    patcher_copy.dont_pass = getattr(self, "dont_pass", False) +    return patcher_copy + + +# Monkey-patch the patcher class :) +_copy.original = unittest.mock._patch.copy +unittest.mock._patch.copy = _copy +unittest.mock._patch.decoration_helper = _decoration_helper + + +def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable: +    """ +    Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + +    If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object. +    """ +    # Caller's kwargs should take priority and overwrite the defaults. +    kwargs = dict(spec_set=True, autospec=True) +    kwargs.update(patch_kwargs) + +    # Import the target if it's a string. +    # This is to support both object and string targets like patch.multiple. +    if type(target) is str: +        target = unittest.mock._importer(target) + +    def decorator(func): +        for attribute in attributes: +            patcher = unittest.mock.patch.object(target, attribute, **kwargs) +            if not pass_mocks: +                # A custom attribute to keep track of which patchings should be skipped. +                patcher.dont_pass = True +            func = patcher(func) +        return func +    return decorator diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index c0a1da35c..9f380a15d 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,7 +1,6 @@  import unittest -from unittest import mock -from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User +from bot.exts.backend.sync._syncers import UserSyncer, _Diff  from tests import helpers @@ -10,7 +9,7 @@ def fake_user(**kwargs):      kwargs.setdefault("id", 43)      kwargs.setdefault("name", "bob the test man")      kwargs.setdefault("discriminator", 1337) -    kwargs.setdefault("roles", (666,)) +    kwargs.setdefault("roles", [666])      kwargs.setdefault("in_guild", True)      return kwargs @@ -40,22 +39,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          return guild +    @staticmethod +    def get_mock_member(member: dict): +        member = member.copy() +        del member["in_guild"] +        mock_member = helpers.MockMember(**member) +        mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] +        return mock_member +      async def test_empty_diff_for_no_users(self):          """When no users are given, an empty diff should be returned.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [] +        }          guild = self.get_guild()          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_identical_users(self):          """No differences should be found if the users in the guild and DB are identical.""" -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.return_value = self.get_mock_member(fake_user())          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -63,59 +82,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):          """Only updated users should be added to the 'updated' set of the diff."""          updated_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(id=99, name="old"), fake_user()] +        }          guild = self.get_guild(updated_user, fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(updated_user), +            self.get_mock_member(fake_user()) +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**updated_user)}, None) +        expected_diff = ([], [{"id": 99, "name": "new"}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_users(self): -        """Only new users should be added to the 'created' set of the diff.""" +        """Only new users should be added to the 'created' list of the diff."""          new_user = fake_user(id=99, name="new") -        self.bot.api_client.get.return_value = [fake_user()] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user()] +        }          guild = self.get_guild(fake_user(), new_user) - +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            self.get_mock_member(new_user) +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, set(), None) +        expected_diff = ([new_user], [], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_sets_in_guild_false_for_leaving_users(self):          """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" -        leaving_user = fake_user(id=63, in_guild=False) - -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=63)] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            None +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), {_User(**leaving_user)}, None) +        expected_diff = ([], [{"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_diff_for_new_updated_and_leaving_users(self):          """When users are added, updated, and removed, all of them are returned properly."""          new_user = fake_user(id=99, name="new") +          updated_user = fake_user(id=55, name="updated") -        leaving_user = fake_user(id=63, in_guild=False) -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=55), fake_user(id=63)] +        }          guild = self.get_guild(fake_user(), new_user, updated_user) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            self.get_mock_member(updated_user), +            None +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) +        expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)          self.assertEqual(actual_diff, expected_diff)      async def test_empty_diff_for_db_users_not_in_guild(self): -        """When the DB knows a user the guild doesn't, no difference is found.""" -        self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] +        """When the DB knows a user, but the guild doesn't, no difference is found.""" +        self.bot.api_client.get.return_value = { +            "count": 3, +            "next_page_no": None, +            "previous_page_no": None, +            "results": [fake_user(), fake_user(id=63, in_guild=False)] +        }          guild = self.get_guild(fake_user()) +        guild.get_member.side_effect = [ +            self.get_mock_member(fake_user()), +            None +        ]          actual_diff = await self.syncer._get_diff(guild) -        expected_diff = (set(), set(), None) +        expected_diff = ([], [], None)          self.assertEqual(actual_diff, expected_diff) @@ -131,13 +193,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          """Only POST requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] -        user_tuples = {_User(**user) for user in users} -        diff = _Diff(user_tuples, set(), None) +        diff = _Diff(users, [], None)          await self.syncer._sync(diff) -        calls = [mock.call("bot/users", json=user) for user in users] -        self.bot.api_client.post.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.post.call_count, len(users)) +        self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)          self.bot.api_client.put.assert_not_called()          self.bot.api_client.delete.assert_not_called() @@ -146,13 +205,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):          """Only PUT requests should be made with the correct payload."""          users = [fake_user(id=111), fake_user(id=222)] -        user_tuples = {_User(**user) for user in users} -        diff = _Diff(set(), user_tuples, None) +        diff = _Diff([], users, None)          await self.syncer._sync(diff) -        calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] -        self.bot.api_client.put.assert_has_calls(calls, any_order=True) -        self.assertEqual(self.bot.api_client.put.call_count, len(users)) +        self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)          self.bot.api_client.post.assert_not_called()          self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index 36a35c8e2..daede54c5 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -92,77 +92,6 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(admin_embed.title, "Admins info")          self.assertEqual(admin_embed.colour, discord.Colour.red()) -    @unittest.mock.patch('bot.exts.info.information.time_since') -    async def test_server_info_command(self, time_since_patch): -        time_since_patch.return_value = '2 days ago' - -        self.ctx.guild = helpers.MockGuild( -            features=('lemons', 'apples'), -            region="The Moon", -            roles=[self.moderator_role], -            channels=[ -                discord.TextChannel( -                    state={}, -                    guild=self.ctx.guild, -                    data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} -                ), -                discord.CategoryChannel( -                    state={}, -                    guild=self.ctx.guild, -                    data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} -                ), -                discord.VoiceChannel( -                    state={}, -                    guild=self.ctx.guild, -                    data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} -                ) -            ], -            members=[ -                *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), -                *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), -                *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), -                *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), -            ], -            member_count=1_234, -            icon_url='a-lemon.jpg', -        ) - -        self.assertIsNone(await self.cog.server_info(self.cog, self.ctx)) - -        time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') -        _, kwargs = self.ctx.send.call_args -        embed = kwargs.pop('embed') -        self.assertEqual(embed.colour, discord.Colour.blurple()) -        self.assertEqual( -            embed.description, -            textwrap.dedent( -                f""" -                **Server information** -                Created: {time_since_patch.return_value} -                Voice region: {self.ctx.guild.region} -                Features: {', '.join(self.ctx.guild.features)} - -                **Channel counts** -                Category channels: 1 -                Text channels: 1 -                Voice channels: 1 -                Staff channels: 0 - -                **Member counts** -                Members: {self.ctx.guild.member_count:,} -                Staff members: 0 -                Roles: {len(self.ctx.guild.roles)} - -                **Member statuses** -                {constants.Emojis.status_online} 2 -                {constants.Emojis.status_idle} 1 -                {constants.Emojis.status_dnd} 4 -                {constants.Emojis.status_offline} 3 -                """ -            ) -        ) -        self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') -  class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):      """Tests for the helper methods of the `!user` command.""" @@ -465,7 +394,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          self.assertEqual(              "basic infractions info", -            embed.fields[3].value +            embed.fields[2].value          )      @unittest.mock.patch( diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index be1b649e1..bf557a484 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,7 +1,8 @@  import textwrap  import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch +from bot.constants import Event  from bot.exts.moderation.infraction.infractions import Infractions  from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.apply_infraction.assert_awaited_once_with(              self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value          ) + + +@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) +class VoiceBanTests(unittest.IsolatedAsyncioTestCase): +    """Tests for voice ban related functions and commands.""" + +    def setUp(self): +        self.bot = MockBot() +        self.mod = MockMember(top_role=10) +        self.user = MockMember(top_role=1, roles=[MockRole(id=123456)]) +        self.guild = MockGuild() +        self.ctx = MockContext(bot=self.bot, author=self.mod) +        self.cog = Infractions(self.bot) + +    async def test_permanent_voice_ban(self): +        """Should call voice ban applying function without expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar") + +    async def test_temporary_voice_ban(self): +        """Should call voice ban applying function with expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + +    async def test_voice_unban(self): +        """Should call infraction pardoning function.""" +        self.cog.pardon_infraction = AsyncMock() +        self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) +        self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): +        """Should return early when user already have Voice Ban infraction.""" +        get_active_infraction.return_value = {"foo": "bar"} +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") +        post_infraction_mock.assert_not_awaited() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): +        """Should return early when posting infraction fails.""" +        self.cog.mod_log.ignore = MagicMock() +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        post_infraction_mock.assert_awaited_once() +        self.cog.mod_log.ignore.assert_not_called() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): +        """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" +        get_active_infraction.return_value = None +        # We don't want that this continue yet +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) +        post_infraction_mock.assert_awaited_once_with( +            self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 +        ) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar") +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): +        """Should truncate reason for voice ban.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) +        self.user.remove_roles.assert_called_once_with( +            self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...") +        ) +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    async def test_voice_unban_user_not_found(self): +        """Should include info to return dict when user was not found from guild.""" +        self.guild.get_member.return_value = None +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, {"Info": "User was not found in the guild."}) + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = True +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "Sent" +        }) +        notify_pardon_mock.assert_awaited_once() + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = False +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "**Failed**" +        }) +        notify_pardon_mock.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 3c2d52ae0..104293d8e 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,23 +1,49 @@ +import asyncio  import unittest +from datetime import datetime, timezone  from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock +from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Emojis, Guild, Roles -from bot.exts.moderation.silence import Silence, SilenceNotifier -from tests.helpers import MockBot, MockContext, MockTextChannel +from bot.constants import Channels, Guild, Roles +from bot.exts.moderation import silence +from tests.helpers import MockBot, MockContext, MockTextChannel, autospec + +redis_session = None +redis_loop = asyncio.get_event_loop() + + +def setUpModule():  # noqa: N802 +    """Create and connect to the fakeredis session.""" +    global redis_session +    redis_session = RedisSession(use_fakeredis=True) +    redis_loop.run_until_complete(redis_session.connect()) + + +def tearDownModule():  # noqa: N802 +    """Close the fakeredis session.""" +    if redis_session: +        redis_loop.run_until_complete(redis_session.close()) + + +# Have to subclass it because builtins can't be patched. +class PatchedDatetime(datetime): +    """A datetime object with a mocked now() function.""" + +    now = mock.create_autospec(datetime, "now")  class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None:          self.alert_channel = MockTextChannel() -        self.notifier = SilenceNotifier(self.alert_channel) +        self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock()          self.notifier.start = self.notifier_start_mock = Mock()      def test_add_channel_adds_channel(self): -        """Channel in FirstHash with current loop is added to internal set.""" +        """Channel is added to `_silenced_channels` with the current loop."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.add_channel(channel) @@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):          self.notifier_start_mock.assert_not_called()      def test_remove_channel_removes_channel(self): -        """Channel in FirstHash is removed from `_silenced_channels`.""" +        """Channel is removed from `_silenced_channels`."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.remove_channel(channel) @@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):              with self.subTest(current_loop=current_loop):                  with mock.patch.object(self.notifier, "_current_loop", new=current_loop):                      await self.notifier._notifier() -                self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") +                self.alert_channel.send.assert_called_once_with( +                    f"<@&{Roles.moderators}> currently silenced channels: " +                )              self.alert_channel.send.reset_mock()      async def test_notifier_skips_alert(self): @@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):                      self.alert_channel.send.assert_not_called() -class SilenceTests(unittest.IsolatedAsyncioTestCase): +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the general functionality of the Silence cog.""" + +    @autospec(silence, "Scheduler", pass_mocks=False)      def setUp(self) -> None:          self.bot = MockBot() -        self.cog = Silence(self.bot) -        self.ctx = MockContext() -        self.cog._verified_role = None -        # Set event so command callbacks can continue. -        self.cog._get_instance_vars_event.set() +        self.cog = silence.Silence(self.bot) -    async def test_instance_vars_got_guild(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog._get_instance_vars() -        self.bot.wait_until_guild_available.assert_called_once() +        await self.cog._async_init() +        self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id) -    async def test_instance_vars_got_role(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_role(self):          """Got `Roles.verified` role from guild.""" -        await self.cog._get_instance_vars()          guild = self.bot.get_guild() -        guild.get_role.assert_called_once_with(Roles.verified) +        guild.get_role.side_effect = lambda id_: Mock(id=id_) -    async def test_instance_vars_got_channels(self): +        await self.cog._async_init() +        self.assertEqual(self.cog._verified_role.id, Roles.verified) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_channels(self):          """Got channels from bot.""" -        await self.cog._get_instance_vars() -        self.bot.get_channel.called_once_with(Channels.mod_alerts) -        self.bot.get_channel.called_once_with(Channels.mod_log) +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) -    @mock.patch("bot.exts.moderation.silence.SilenceNotifier") -    async def test_instance_vars_got_notifier(self, notifier): +    @autospec(silence, "SilenceNotifier") +    async def test_async_init_got_notifier(self, notifier):          """Notifier was started with channel.""" -        mod_log = MockTextChannel() -        self.bot.get_channel.side_effect = (None, mod_log) -        await self.cog._get_instance_vars() -        notifier.assert_called_once_with(mod_log) -        self.bot.get_channel.side_effect = None - -    async def test_silence_sent_correct_discord_message(self): -        """Check if proper message was sent when called with duration in channel with previous state.""" +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) +        self.assertEqual(self.cog.notifier, notifier.return_value) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_rescheduled(self): +        """`_reschedule_` coroutine was awaited.""" +        self.cog._reschedule = mock.create_autospec(self.cog._reschedule) +        await self.cog._async_init() +        self.cog._reschedule.assert_awaited_once_with() + +    def test_cog_unload_cancelled_tasks(self): +        """The init task was cancelled.""" +        self.cog._init_task = asyncio.Future() +        self.cog.cog_unload() + +        # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. +        self.assertTrue(self.cog._init_task.cancelled()) + +    @autospec("discord.ext.commands", "has_any_role") +    @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) +    async def test_cog_check(self, role_check): +        """Role check was called with `MODERATION_ROLES`""" +        ctx = MockContext() +        role_check.return_value.predicate = mock.AsyncMock() + +        await self.cog.cog_check(ctx) +        role_check.assert_called_once_with(*(1, 2, 3)) +        role_check.return_value.predicate.assert_awaited_once_with(ctx) + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class RescheduleTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the rescheduling of cached unsilences.""" + +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self): +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) + +        with mock.patch.object(self.cog, "_reschedule", autospec=True): +            asyncio.run(self.cog._async_init())  # Populate instance attributes. + +    async def test_skipped_missing_channel(self): +        """Did nothing because the channel couldn't be retrieved.""" +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] +        self.bot.get_channel.return_value = None + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_added_permanent_to_notifier(self): +        """Permanently silenced channels were added to the notifier.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_any_call(channels[0]) +        self.cog.notifier.add_channel.assert_any_call(channels[1]) + +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_unsilenced_expired(self): +        """Unsilenced expired silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] + +        await self.cog._reschedule() + +        self.cog._unsilence_wrapper.assert_any_call(channels[0]) +        self.cog._unsilence_wrapper.assert_any_call(channels[1]) + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    @mock.patch.object(silence, "datetime", new=PatchedDatetime) +    async def test_rescheduled_active(self): +        """Rescheduled active silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] +        silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc) + +        self.cog._unsilence_wrapper = mock.MagicMock() +        unsilence_return = self.cog._unsilence_wrapper.return_value + +        await self.cog._reschedule() + +        # Yuck. +        calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)] +        self.cog.scheduler.schedule_later.assert_has_calls(calls) + +        unsilence_calls = [mock.call(channel) for channel in channels] +        self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls) + +        self.cog.notifier.add_channel.assert_not_called() + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the silence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        # Avoid unawaited coroutine warnings. +        self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command."""          test_cases = ( -            (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), -            (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), -            (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), +            (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), +            (None, silence.MSG_SILENCE_PERMANENT, True,), +            (5, silence.MSG_SILENCE_FAIL, False,),          ) -        for duration, result_message, _silence_patch_return in test_cases: -            with self.subTest( -                silence_duration=duration, -                result_message=result_message, -                starting_unsilenced_state=_silence_patch_return -            ): -                with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): -                    await self.cog.silence(self.cog, self.ctx, duration) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_unsilence_sent_correct_discord_message(self): -        """Check if proper message was sent when unsilencing channel.""" -        test_cases = ( -            (True, f"{Emojis.check_mark} unsilenced current channel."), -            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        for duration, message, was_silenced in test_cases: +            ctx = MockContext() +            with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): +                with self.subTest(was_silenced=was_silenced, message=message, duration=duration): +                    await self.cog.silence.callback(self.cog, ctx, duration) +                    ctx.send.assert_called_once_with(message) + +    async def test_skipped_already_silenced(self): +        """Permissions were not set and `False` was returned for an already silenced channel.""" +        subtests = ( +            (False, PermissionOverwrite(send_messages=False, add_reactions=False)), +            (True, PermissionOverwrite(send_messages=True, add_reactions=True)), +            (True, PermissionOverwrite(send_messages=False, add_reactions=False)),          ) -        for _unsilence_patch_return, result_message in test_cases: -            with self.subTest( -                starting_silenced_state=_unsilence_patch_return, -                result_message=result_message -            ): -                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): -                    await self.cog.unsilence(self.cog, self.ctx) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_silence_private_for_false(self): -        """Permissions are not set and `False` is returned in an already silenced channel.""" -        perm_overwrite = Mock(send_messages=False) -        channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - -        self.assertFalse(await self.cog._silence(channel, True, None)) -        channel.set_permissions.assert_not_called() -    async def test_silence_private_silenced_channel(self): -        """Channel had `send_message` permissions revoked.""" -        channel = MockTextChannel() -        self.assertTrue(await self.cog._silence(channel, False, None)) -        channel.set_permissions.assert_called_once() -        self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) +        for contains, overwrite in subtests: +            with self.subTest(contains=contains, overwrite=overwrite): +                self.cog.scheduler.__contains__.return_value = contains +                channel = MockTextChannel() +                channel.overwrites_for.return_value = overwrite + +                self.assertFalse(await self.cog._set_silence_overwrites(channel)) +                channel.set_permissions.assert_not_called() + +    async def test_silenced_channel(self): +        """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) +        self.assertFalse(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite +        ) -    async def test_silence_private_preserves_permissions(self): -        """Previous permissions were preserved when channel was silenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite() -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._silence(channel, False, None) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    async def test_silence_private_notifier(self): -        """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" -        channel = MockTextChannel() -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=True): -                await self.cog._silence(channel, True, None) -                self.cog.notifier.add_channel.assert_called_once() - -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=False): -                await self.cog._silence(channel, False, None) -                self.cog.notifier.add_channel.assert_not_called() - -    async def test_silence_private_added_muted_channel(self): -        """Channel was added to `muted_channels` on silence.""" +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed.""" +        prev_overwrite_dict = dict(self.overwrite) +        await self.cog._set_silence_overwrites(self.channel) +        new_overwrite_dict = dict(self.overwrite) + +        # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. +        del prev_overwrite_dict['send_messages'] +        del prev_overwrite_dict['add_reactions'] +        del new_overwrite_dict['send_messages'] +        del new_overwrite_dict['add_reactions'] + +        self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_temp_not_added_to_notifier(self): +        """Channel was not added to notifier if a duration was set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_indefinite_added_to_notifier(self): +        """Channel was added to notifier if a duration was not set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), None) +            self.cog.notifier.add_channel.assert_called_once() + +    async def test_silenced_not_added_to_notifier(self): +        """Channel was not added to the notifier if it was already silenced.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_cached_previous_overwrites(self): +        """Channel's previous overwrites were cached.""" +        overwrite_json = '{"send_messages": true, "add_reactions": false}' +        await self.cog._set_silence_overwrites(self.channel) +        self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + +    @autospec(silence, "datetime") +    async def test_cached_unsilence_time(self, datetime_mock): +        """The UTC POSIX timestamp for the unsilence was cached.""" +        now_timestamp = 100 +        duration = 15 +        timestamp = now_timestamp + duration * 60 +        datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) + +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, duration) + +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) +        datetime_mock.now.assert_called_once_with(tz=timezone.utc)  # Ensure it's using an aware dt. + +    async def test_cached_indefinite_time(self): +        """A value of -1 was cached for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) + +    async def test_scheduled_task(self): +        """An unsilence task was scheduled.""" +        ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + +        await self.cog.silence.callback(self.cog, ctx, 5) + +        args = (300, ctx.channel.id, ctx.invoke.return_value) +        self.cog.scheduler.schedule_later.assert_called_once_with(*args) +        ctx.invoke.assert_called_once_with(self.cog.unsilence) + +    async def test_permanent_not_scheduled(self): +        """A task was not scheduled for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.scheduler.schedule_later.assert_not_called() + + +@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) +class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the unsilence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.cog.scheduler.__contains__.return_value = True +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command.""" +        unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) +        test_cases = ( +            (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), +        ) +        for was_unsilenced, message, overwrite in test_cases: +            ctx = MockContext() +            with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): +                with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): +                    ctx.channel.overwrites_for.return_value = overwrite +                    await self.cog.unsilence.callback(self.cog, ctx) +                    ctx.channel.send.assert_called_once_with(message) + +    async def test_skipped_already_unsilenced(self): +        """Permissions were not set and `False` was returned for an already unsilenced channel.""" +        self.cog.scheduler.__contains__.return_value = False +        self.cog.previous_overwrites.get.return_value = None          channel = MockTextChannel() -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._silence(channel, False, None) -        muted_channels.add.assert_called_once_with(channel) -    async def test_unsilence_private_for_false(self): -        """Permissions are not set and `False` is returned in an unsilenced channel.""" -        channel = Mock()          self.assertFalse(await self.cog._unsilence(channel))          channel.set_permissions.assert_not_called() -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_unsilenced_channel(self, _): -        """Channel had `send_message` permissions restored""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        self.assertTrue(await self.cog._unsilence(channel)) -        channel.set_permissions.assert_called_once() -        self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_notifier(self, notifier): -        """Channel was removed from `notifier` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        await self.cog._unsilence(channel) -        notifier.remove_channel.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_muted_channel(self, _): -        """Channel was removed from `muted_channels` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._unsilence(channel) -        muted_channels.discard.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_preserves_permissions(self, _): -        """Previous permissions were preserved when channel was unsilenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite(send_messages=False) -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._unsilence(channel) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    @mock.patch.object(Silence, "_mod_alerts_channel", create=True) -    def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): -        """Task for sending an alert was created with present `muted_channels`.""" -        with mock.patch.object(self.cog, "muted_channels"): -            self.cog.cog_unload() -            alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") -            asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    def test_cog_unload_skips_task_start(self, asyncio_mock): -        """No task created with no channels.""" -        self.cog.cog_unload() -        asyncio_mock.create_task.assert_not_called() +    async def test_restored_overwrites(self): +        """Channel's `send_message` and `add_reactions` overwrites were restored.""" +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) -    @mock.patch("discord.ext.commands.has_any_role") -    @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) -    async def test_cog_check(self, role_check): -        """Role check is called with `MODERATION_ROLES`""" -        role_check.return_value.predicate = mock.AsyncMock() -        await self.cog.cog_check(self.ctx) -        role_check.assert_called_once_with(*(1, 2, 3)) -        role_check.return_value.predicate.assert_awaited_once_with(self.ctx) +        # Recall that these values are determined by the fixture. +        self.assertTrue(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) + +    async def test_cache_miss_used_default_overwrites(self): +        """Both overwrites were set to None due previous values not being found in the cache.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) + +        self.assertIsNone(self.overwrite.send_messages) +        self.assertIsNone(self.overwrite.add_reactions) + +    async def test_cache_miss_sent_mod_alert(self): +        """A message was sent to the mod alerts channel.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.cog._mod_alerts_channel.send.assert_awaited_once() + +    async def test_removed_notifier(self): +        """Channel was removed from `notifier`.""" +        await self.cog._unsilence(self.channel) +        self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + +    async def test_deleted_cached_overwrite(self): +        """Channel was deleted from the overwrites cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + +    async def test_deleted_cached_time(self): +        """Channel was deleted from the timestamp cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + +    async def test_cancelled_task(self): +        """The scheduled unsilence task should be cancelled.""" +        await self.cog._unsilence(self.channel) +        self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed, including cache misses.""" +        for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): +            with self.subTest(overwrite_json=overwrite_json): +                self.cog.previous_overwrites.get.return_value = overwrite_json + +                prev_overwrite_dict = dict(self.overwrite) +                await self.cog._unsilence(self.channel) +                new_overwrite_dict = dict(self.overwrite) + +                # Remove these keys because they were modified by the unsilence. +                del prev_overwrite_dict['send_messages'] +                del prev_overwrite_dict['add_reactions'] +                del new_overwrite_dict['send_messages'] +                del new_overwrite_dict['add_reactions'] + +                self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) diff --git a/tests/helpers.py b/tests/helpers.py index e47fdf28f..870f66197 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools  import logging  import unittest.mock  from asyncio import AbstractEventLoop -from typing import Callable, Iterable, Optional +from typing import Iterable, Optional  import discord  from aiohttp import ClientSession @@ -14,6 +14,7 @@ from discord.ext.commands import Context  from bot.api import APIClient  from bot.async_stats import AsyncStatsClient  from bot.bot import Bot +from tests._autospec import autospec  # noqa: F401 other modules import it via this module  for logger in logging.Logger.manager.loggerDict.values(): @@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> Callable: -    """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" -    # Caller's kwargs should take priority and overwrite the defaults. -    kwargs = {'spec_set': True, 'autospec': True, **kwargs} - -    # Import the target if it's a string. -    # This is to support both object and string targets like patch.multiple. -    if type(target) is str: -        target = unittest.mock._importer(target) - -    def decorator(func): -        for attribute in attributes: -            patcher = unittest.mock.patch.object(target, attribute, **kwargs) -            func = patcher(func) -        return func -    return decorator - -  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.  |