diff options
| author | 2020-11-07 09:24:36 +0200 | |
|---|---|---|
| committer | 2020-11-07 09:24:36 +0200 | |
| commit | 7301ebc6c687358ae4fc3496ee8cb9a59cec8f71 (patch) | |
| tree | 7a32b64494fa58a445ece5d51391980d73e95d9d | |
| parent | Merge remote-tracking branch 'origin/bug-fixes' into bug-fixes (diff) | |
| parent | Merge pull request #1271 from spacecraft1013/code_instructions (diff) | |
Merge branch 'master' into bug-fixes
41 files changed, 2173 insertions, 811 deletions
| diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY new file mode 100644 index 000000000..eacd9b952 --- /dev/null +++ b/LICENSE-THIRD-PARTY @@ -0,0 +1,88 @@ +--------------------------------------------------------------------------------------------------- +                                       BSD 3-Clause License +Applies to: +    - Copyright (c) 2008-Present, IPython Development Team +      Copyright (c) 2001-2007, Fernando Perez <[email protected]> +      Copyright (c) 2001, Janko Hauser <[email protected]> +      Copyright (c) 2001, Nathaniel Gray <[email protected]> +      All rights reserved. +        - bot/exts/info/codeblock/_parsing.py: _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL +--------------------------------------------------------------------------------------------------- + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this +  list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, +  this list of conditions and the following disclaimer in the documentation +  and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its +  contributors may be used to endorse or promote products derived from +  this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--------------------------------------------------------------------------------------------------- +                           PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +Applies to: +    - Copyright © 2001-2020 Python Software Foundation. All rights reserved. +        - tests/_autospec.py: _decoration_helper +--------------------------------------------------------------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation; +All Rights Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis.  PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED.  BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee.  This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. @@ -1,6 +1,6 @@  # Python Utility Bot -[](https://discord.gg/2B963hn) +[](https://discord.gg/2B963hn)  [](https://dev.azure.com/python-discord/Python%20Discord/_build/latest?definitionId=1&branchName=master)  [](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)  [](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master) diff --git a/bot/__init__.py b/bot/__init__.py index 3ee70c4e9..4fce04532 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -64,7 +64,7 @@ coloredlogs.install(logger=root_log, stream=sys.stdout)  logging.getLogger("discord").setLevel(logging.WARNING)  logging.getLogger("websockets").setLevel(logging.WARNING)  logging.getLogger("chardet").setLevel(logging.WARNING) -logging.getLogger(__name__) +logging.getLogger("async_rediscache").setLevel(logging.WARNING)  # On Windows, the selector event loop is required for aiodns. diff --git a/bot/__main__.py b/bot/__main__.py index da042a5ed..367be1300 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -58,7 +58,7 @@ bot = Bot(      redis_session=redis_session,      loop=loop,      command_prefix=when_mentioned_or(constants.Bot.prefix), -    activity=discord.Game(name="Commands: !help"), +    activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),      case_insensitive=True,      max_messages=10_000,      allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), diff --git a/bot/constants.py b/bot/constants.py index c21fd52e0..4d41f4eb2 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -377,6 +377,7 @@ class Categories(metaclass=YAMLGetter):      help_in_use: int      help_dormant: int      modmail: int +    voice: int  class Channels(metaclass=YAMLGetter): @@ -392,6 +393,7 @@ class Channels(metaclass=YAMLGetter):      bot_commands: int      change_log: int      code_help_voice: int +    code_help_voice_2: int      cooldown: int      defcon: int      dev_contrib: int @@ -424,6 +426,8 @@ class Channels(metaclass=YAMLGetter):      user_event_announcements: int      user_log: int      verification: int +    voice_chat: int +    voice_gate: int      voice_log: int @@ -456,9 +460,11 @@ class Roles(metaclass=YAMLGetter):      owners: int      partners: int      python_community: int +    sprinters: int      team_leaders: int      unverified: int      verified: int  # This is the Developers role on PyDis, here named verified for readability reasons. +    voice_verified: int  class Guild(metaclass=YAMLGetter): @@ -467,6 +473,7 @@ class Guild(metaclass=YAMLGetter):      id: int      invite: str  # Discord invite, gets embedded in chat      moderation_channels: List[int] +    moderation_categories: List[int]      moderation_roles: List[int]      modlog_blacklist: List[int]      reminder_whitelist: List[int] @@ -528,6 +535,15 @@ class BigBrother(metaclass=YAMLGetter):      header_message_limit: int +class CodeBlock(metaclass=YAMLGetter): +    section = 'code_block' + +    channel_whitelist: List[int] +    cooldown_channels: List[int] +    cooldown_seconds: int +    minimum_lines: int + +  class Free(metaclass=YAMLGetter):      section = 'free' @@ -578,6 +594,15 @@ class Verification(metaclass=YAMLGetter):      kick_confirmation_threshold: float +class VoiceGate(metaclass=YAMLGetter): +    section = "voice_gate" + +    minimum_days_verified: int +    minimum_messages: int +    bot_message_delete_delay: int +    minimum_activity_blocks: int + +  class Event(Enum):      """      Event names. This does not include every event (for example, raw @@ -618,6 +643,9 @@ STAFF_ROLES = Guild.staff_roles  # Channel combinations  MODERATION_CHANNELS = Guild.moderation_channels +# Category combinations +MODERATION_CATEGORIES = Guild.moderation_categories +  # Bot replies  NEGATIVE_REPLIES = [      "Noooooo!!", diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py index 7894ec48f..26f00e91f 100644 --- a/bot/exts/filters/antimalware.py +++ b/bot/exts/filters/antimalware.py @@ -6,7 +6,7 @@ from discord import Embed, Message, NotFound  from discord.ext.commands import Cog  from bot.bot import Bot -from bot.constants import Channels, STAFF_ROLES, URLs +from bot.constants import Channels, Filter, URLs  log = logging.getLogger(__name__) @@ -61,7 +61,7 @@ class AntiMalware(Cog):          # Check if user is staff, if is, return          # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance -        if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): +        if hasattr(message.author, "roles") and any(role.id in Filter.role_whitelist for role in message.author.roles):              return          embed = Embed() diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index 4964283f1..af8528a68 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -15,7 +15,6 @@ from bot.constants import (      AntiSpam as AntiSpamConfig, Channels,      Colours, DEBUG_MODE, Event, Filter,      Guild as GuildConfig, Icons, -    STAFF_ROLES,  )  from bot.converters import Duration  from bot.exts.moderation.modlog import ModLog @@ -149,7 +148,7 @@ class AntiSpam(Cog):              or message.guild.id != GuildConfig.id              or message.author.bot              or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE) -            or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE) +            or (any(role.id in Filter.role_whitelist for role in message.author.roles) and not DEBUG_MODE)          ):              return diff --git a/bot/exts/filters/filtering.py b/bot/exts/filters/filtering.py index 92cdfb8f5..208fc9e1f 100644 --- a/bot/exts/filters/filtering.py +++ b/bot/exts/filters/filtering.py @@ -246,7 +246,7 @@ class Filtering(Cog):                              filter_triggered = True                          stats = self._add_stats(filter_name, match, result) -                        await self._send_log(filter_name, _filter["type"], msg, stats, is_eval=True) +                        await self._send_log(filter_name, _filter, msg, stats, is_eval=True)                          break  # We don't want multiple filters to trigger diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py index 82084ea88..48aa2749c 100644 --- a/bot/exts/fun/duck_pond.py +++ b/bot/exts/fun/duck_pond.py @@ -22,6 +22,7 @@ class DuckPond(Cog):          self.bot = bot          self.webhook_id = constants.Webhooks.duck_pond          self.webhook = None +        self.ducked_messages = []          self.bot.loop.create_task(self.fetch_webhook())          self.relay_lock = None @@ -176,7 +177,8 @@ class DuckPond(Cog):          duck_count = await self.count_ducks(message)          # If we've got more than the required amount of ducks, send the message to the duck_pond. -        if duck_count >= constants.DuckPond.threshold: +        if duck_count >= constants.DuckPond.threshold and message.id not in self.ducked_messages: +            self.ducked_messages.append(message.id)              await self.locked_relay(message)      @Cog.listener() diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py index f5c9a5dd0..062d4fcfe 100644 --- a/bot/exts/help_channels.py +++ b/bot/exts/help_channels.py @@ -14,6 +14,7 @@ from discord.ext import commands  from bot import constants  from bot.bot import Bot +from bot.utils import channel as channel_utils  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) @@ -378,11 +379,18 @@ class HelpChannels(commands.Cog):          log.trace("Getting the CategoryChannel objects for the help categories.")          try: -            self.available_category = await self.try_get_channel( -                constants.Categories.help_available +            self.available_category = await channel_utils.try_get_channel( +                constants.Categories.help_available, +                self.bot +            ) +            self.in_use_category = await channel_utils.try_get_channel( +                constants.Categories.help_in_use, +                self.bot +            ) +            self.dormant_category = await channel_utils.try_get_channel( +                constants.Categories.help_dormant, +                self.bot              ) -            self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use) -            self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant)          except discord.HTTPException:              log.exception("Failed to get a category; cog will be removed")              self.bot.remove_cog(self.qualified_name) @@ -442,12 +450,6 @@ class HelpChannels(commands.Cog):              return False          return message.author == self.bot.user and bot_msg_desc.strip() == description.strip() -    @staticmethod -    def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: -        """Return True if `channel` is within a category with `category_id`.""" -        actual_category = getattr(channel, "category", None) -        return actual_category is not None and actual_category.id == category_id -      async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None:          """          Make the `channel` dormant if idle or schedule the move if still active. @@ -498,7 +500,7 @@ class HelpChannels(commands.Cog):          options should be avoided, as it may interfere with the category move we perform.          """          # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. -        category = await self.try_get_channel(category_id) +        category = await channel_utils.try_get_channel(category_id, self.bot)          payload = [{"id": c.id, "position": c.position} for c in category.channels] @@ -646,7 +648,7 @@ class HelpChannels(commands.Cog):          channel = message.channel          # Confirm the channel is an in use help channel -        if self.is_in_category(channel, constants.Categories.help_in_use): +        if channel_utils.is_in_category(channel, constants.Categories.help_in_use):              log.trace(f"Checking if #{channel} ({channel.id}) has been answered.")              # Check if there is an entry in unanswered @@ -671,7 +673,8 @@ class HelpChannels(commands.Cog):          await self.check_for_answer(message) -        if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel): +        is_available = channel_utils.is_in_category(channel, constants.Categories.help_available) +        if not is_available or self.is_excluded_channel(channel):              return  # Ignore messages outside the Available category or in excluded channels.          log.trace("Waiting for the cog to be ready before processing messages.") @@ -681,7 +684,7 @@ class HelpChannels(commands.Cog):          async with self.on_message_lock:              log.trace(f"on_message lock acquired for {message.id}.") -            if not self.is_in_category(channel, constants.Categories.help_available): +            if not channel_utils.is_in_category(channel, constants.Categories.help_available):                  log.debug(                      f"Message {message.id} will not make #{channel} ({channel.id}) in-use "                      f"because another message in the channel already triggered that." @@ -719,7 +722,7 @@ class HelpChannels(commands.Cog):          The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`.          """ -        if not self.is_in_category(msg.channel, constants.Categories.help_in_use): +        if not channel_utils.is_in_category(msg.channel, constants.Categories.help_in_use):              return          if not await self.is_empty(msg.channel): @@ -844,18 +847,6 @@ class HelpChannels(commands.Cog):              log.trace(f"Dormant message not found in {channel_info}; sending a new message.")              await channel.send(embed=embed) -    async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel: -        """Attempt to get or fetch a channel and return it.""" -        log.trace(f"Getting the channel {channel_id}.") - -        channel = self.bot.get_channel(channel_id) -        if not channel: -            log.debug(f"Channel {channel_id} is not in cache; fetching from API.") -            channel = await self.bot.fetch_channel(channel_id) - -        log.trace(f"Channel #{channel} ({channel_id}) retrieved.") -        return channel -      async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool:          """          Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False. diff --git a/bot/exts/info/codeblock/__init__.py b/bot/exts/info/codeblock/__init__.py new file mode 100644 index 000000000..5c55bc5e3 --- /dev/null +++ b/bot/exts/info/codeblock/__init__.py @@ -0,0 +1,8 @@ +from bot.bot import Bot + + +def setup(bot: Bot) -> None: +    """Load the CodeBlockCog cog.""" +    # Defer import to reduce side effects from importing the codeblock package. +    from bot.exts.info.codeblock._cog import CodeBlockCog +    bot.add_cog(CodeBlockCog(bot)) diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py new file mode 100644 index 000000000..1e0feab0d --- /dev/null +++ b/bot/exts/info/codeblock/_cog.py @@ -0,0 +1,186 @@ +import logging +import time +from typing import Optional + +import discord +from discord import Message, RawMessageUpdateEvent +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.exts.filters.token_remover import TokenRemover +from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE +from bot.exts.info.codeblock._instructions import get_instructions +from bot.utils import has_lines +from bot.utils.channel import is_help_channel +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + + +class CodeBlockCog(Cog, name="Code Block"): +    """ +    Detect improperly formatted Markdown code blocks and suggest proper formatting. + +    There are four basic ways in which a code block is considered improperly formatted: + +    1. The code is not within a code block at all +        * Ignored if the code is not valid Python or Python REPL code +    2. Incorrect characters are used for backticks +    3. A language for syntax highlighting is not specified +        * Ignored if the code is not valid Python or Python REPL code +    4. A syntax highlighting language is incorrectly specified +        * Ignored if the language specified doesn't look like it was meant for Python +        * This can go wrong in two ways: +            1. Spaces before the language +            2. No newline immediately following the language + +    Messages or code blocks must meet a minimum line count to be detected. Detecting multiple code +    blocks is supported. However, if at least one code block is correct, then instructions will not +    be sent even if others are incorrect. When multiple incorrect code blocks are found, only the +    first one is used as the basis for the instructions sent. + +    When an issue is detected, an embed is sent containing specific instructions on fixing what +    is wrong. If the user edits their message to fix the code block, the instructions will be +    removed. If they fail to fix the code block with an edit, the instructions will be updated to +    show what is still incorrect after the user's edit. The embed can be manually deleted with a +    reaction. Otherwise, it will automatically be removed after 5 minutes. + +    The cog only detects messages in whitelisted channels. Channels may also have a cooldown on the +    instructions being sent. Note all help channels are also whitelisted with cooldowns enabled. + +    For configurable parameters, see the `code_block` section in config-default.py. +    """ + +    def __init__(self, bot: Bot): +        self.bot = bot + +        # Stores allowed channels plus epoch times since the last instructional messages sent. +        self.channel_cooldowns = dict.fromkeys(constants.CodeBlock.cooldown_channels, 0.0) + +        # Maps users' messages to the messages the bot sent with instructions. +        self.codeblock_message_ids = {} + +    @staticmethod +    def create_embed(instructions: str) -> discord.Embed: +        """Return an embed which displays code block formatting `instructions`.""" +        return discord.Embed(description=instructions) + +    async def get_sent_instructions(self, payload: RawMessageUpdateEvent) -> Optional[Message]: +        """ +        Return the bot's sent instructions message associated with a user's message `payload`. + +        Return None if the message cannot be found. In this case, it's likely the message was +        deleted either manually via a reaction or automatically by a timer. +        """ +        log.trace(f"Retrieving instructions message for ID {payload.message_id}") +        channel = self.bot.get_channel(payload.channel_id) + +        try: +            return await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) +        except discord.NotFound: +            log.debug("Could not find instructions message; it was probably deleted.") +            return None + +    def is_on_cooldown(self, channel: discord.TextChannel) -> bool: +        """ +        Return True if an embed was sent too recently for `channel`. + +        The cooldown is configured by `constants.CodeBlock.cooldown_seconds`. +        Note: only channels in the `channel_cooldowns` have cooldowns enabled. +        """ +        log.trace(f"Checking if #{channel} is on cooldown.") +        cooldown = constants.CodeBlock.cooldown_seconds +        return (time.time() - self.channel_cooldowns.get(channel.id, 0)) < cooldown + +    def is_valid_channel(self, channel: discord.TextChannel) -> bool: +        """Return True if `channel` is a help channel, may be on a cooldown, or is whitelisted.""" +        log.trace(f"Checking if #{channel} qualifies for code block detection.") +        return ( +            is_help_channel(channel) +            or channel.id in self.channel_cooldowns +            or channel.id in constants.CodeBlock.channel_whitelist +        ) + +    async def send_instructions(self, message: discord.Message, instructions: str) -> None: +        """ +        Send an embed with `instructions` on fixing an incorrect code block in a `message`. + +        The embed will be deleted automatically after 5 minutes. +        """ +        log.info(f"Sending code block formatting instructions for message {message.id}.") + +        embed = self.create_embed(instructions) +        bot_message = await message.channel.send(f"Hey {message.author.mention}!", embed=embed) +        self.codeblock_message_ids[message.id] = bot_message.id + +        self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,), self.bot)) + +        # Increase amount of codeblock correction in stats +        self.bot.stats.incr("codeblock_corrections") + +    def should_parse(self, message: discord.Message) -> bool: +        """ +        Return True if `message` should be parsed. + +        A qualifying message: + +        1. Is not authored by a bot +        2. Is in a valid channel +        3. Has more than 3 lines +        4. Has no bot or webhook token +        """ +        return ( +            not message.author.bot +            and self.is_valid_channel(message.channel) +            and has_lines(message.content, constants.CodeBlock.minimum_lines) +            and not TokenRemover.find_token_in_message(message) +            and not WEBHOOK_URL_RE.search(message.content) +        ) + +    @Cog.listener() +    async def on_message(self, msg: Message) -> None: +        """Detect incorrect Markdown code blocks in `msg` and send instructions to fix them.""" +        if not self.should_parse(msg): +            log.trace(f"Skipping code block detection of {msg.id}: message doesn't qualify.") +            return + +        # When debugging, ignore cooldowns. +        if self.is_on_cooldown(msg.channel) and not constants.DEBUG_MODE: +            log.trace(f"Skipping code block detection of {msg.id}: #{msg.channel} is on cooldown.") +            return + +        instructions = get_instructions(msg.content) +        if instructions: +            await self.send_instructions(msg, instructions) + +            if msg.channel.id not in constants.CodeBlock.channel_whitelist: +                log.debug(f"Adding #{msg.channel} to the channel cooldowns.") +                self.channel_cooldowns[msg.channel.id] = time.time() + +    @Cog.listener() +    async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: +        """Delete the instructional message if an edited message had its code blocks fixed.""" +        if payload.message_id not in self.codeblock_message_ids: +            log.trace(f"Ignoring message edit {payload.message_id}: message isn't being tracked.") +            return + +        if payload.data.get("content") is None or payload.data.get("channel_id") is None: +            log.trace(f"Ignoring message edit {payload.message_id}: missing content or channel ID.") +            return + +        # Parse the message to see if the code blocks have been fixed. +        content = payload.data.get("content") +        instructions = get_instructions(content) + +        bot_message = await self.get_sent_instructions(payload) +        if not bot_message: +            return + +        if not instructions: +            log.info("User's incorrect code block has been fixed. Removing instructions message.") +            await bot_message.delete() +            del self.codeblock_message_ids[payload.message_id] +        else: +            log.info("Message edited but still has invalid code blocks; editing the instructions.") +            await bot_message.edit(embed=self.create_embed(instructions)) diff --git a/bot/exts/info/codeblock/_instructions.py b/bot/exts/info/codeblock/_instructions.py new file mode 100644 index 000000000..dadb5e1ef --- /dev/null +++ b/bot/exts/info/codeblock/_instructions.py @@ -0,0 +1,184 @@ +"""This module generates and formats instructional messages about fixing Markdown code blocks.""" + +import logging +from typing import Optional + +from bot.exts.info.codeblock import _parsing + +log = logging.getLogger(__name__) + +_EXAMPLE_PY = "{lang}\nprint('Hello, world!')"  # Make sure to escape any Markdown symbols here. +_EXAMPLE_CODE_BLOCKS = ( +    "\\`\\`\\`{content}\n\\`\\`\\`\n\n" +    "**This will result in the following:**\n" +    "```{content}```" +) + + +def _get_example(language: str) -> str: +    """Return an example of a correct code block using `language` for syntax highlighting.""" +    # Determine the example code to put in the code block based on the language specifier. +    if language.lower() in _parsing.PY_LANG_CODES: +        log.trace(f"Code block has a Python language specifier `{language}`.") +        content = _EXAMPLE_PY.format(lang=language) +    elif language: +        log.trace(f"Code block has a foreign language specifier `{language}`.") +        # It's not feasible to determine what would be a valid example for other languages. +        content = f"{language}\n..." +    else: +        log.trace("Code block has no language specifier.") +        content = "\nHello, world!" + +    return _EXAMPLE_CODE_BLOCKS.format(content=content) + + +def _get_bad_ticks_message(code_block: _parsing.CodeBlock) -> Optional[str]: +    """Return instructions on using the correct ticks for `code_block`.""" +    log.trace("Creating instructions for incorrect code block ticks.") + +    valid_ticks = f"\\{_parsing.BACKTICK}" * 3 +    instructions = ( +        "It looks like you are trying to paste code into this channel.\n\n" +        "You seem to be using the wrong symbols to indicate where the code block should start. " +        f"The correct symbols would be {valid_ticks}, not `{code_block.tick * 3}`." +    ) + +    log.trace("Check if the bad ticks code block also has issues with the language specifier.") +    addition_msg = _get_bad_lang_message(code_block.content) +    if not addition_msg and not code_block.language: +        addition_msg = _get_no_lang_message(code_block.content) + +    # Combine the back ticks message with the language specifier message. The latter will +    # already have an example code block. +    if addition_msg: +        log.trace("Language specifier issue found; appending additional instructions.") + +        # The first line has double newlines which are not desirable when appending the msg. +        addition_msg = addition_msg.replace("\n\n", " ", 1) + +        # Make the first character of the addition lower case. +        instructions += "\n\nFurthermore, " + addition_msg[0].lower() + addition_msg[1:] +    else: +        log.trace("No issues with the language specifier found.") +        example_blocks = _get_example(code_block.language) +        instructions += f"\n\n**Here is an example of how it should look:**\n{example_blocks}" + +    return instructions + + +def _get_no_ticks_message(content: str) -> Optional[str]: +    """If `content` is Python/REPL code, return instructions on using code blocks.""" +    log.trace("Creating instructions for a missing code block.") + +    if _parsing.is_python_code(content): +        example_blocks = _get_example("py") +        return ( +            "It looks like you're trying to paste code into this channel.\n\n" +            "Discord has support for Markdown, which allows you to post code with full " +            "syntax highlighting. Please use these whenever you paste code, as this " +            "helps improve the legibility and makes it easier for us to help you.\n\n" +            f"**To do this, use the following method:**\n{example_blocks}" +        ) +    else: +        log.trace("Aborting missing code block instructions: content is not Python code.") + + +def _get_bad_lang_message(content: str) -> Optional[str]: +    """ +    Return instructions on fixing the Python language specifier for a code block. + +    If `code_block` does not have a Python language specifier, return None. +    If there's nothing wrong with the language specifier, return None. +    """ +    log.trace("Creating instructions for a poorly specified language.") + +    info = _parsing.parse_bad_language(content) +    if not info: +        log.trace("Aborting bad language instructions: language specified isn't Python.") +        return + +    lines = [] +    language = info.language + +    if info.has_leading_spaces: +        log.trace("Language specifier was preceded by a space.") +        lines.append(f"Make sure there are no spaces between the back ticks and `{language}`.") + +    if not info.has_terminal_newline: +        log.trace("Language specifier was not followed by a newline.") +        lines.append( +            f"Make sure you put your code on a new line following `{language}`. " +            f"There must not be any spaces after `{language}`." +        ) + +    if lines: +        lines = " ".join(lines) +        example_blocks = _get_example(language) + +        # Note that _get_bad_ticks_message expects the first line to have two newlines. +        return ( +            f"It looks like you incorrectly specified a language for your code block.\n\n{lines}" +            f"\n\n**Here is an example of how it should look:**\n{example_blocks}" +        ) +    else: +        log.trace("Nothing wrong with the language specifier; no instructions to return.") + + +def _get_no_lang_message(content: str) -> Optional[str]: +    """ +    Return instructions on specifying a language for a code block. + +    If `content` is not valid Python or Python REPL code, return None. +    """ +    log.trace("Creating instructions for a missing language.") + +    if _parsing.is_python_code(content): +        example_blocks = _get_example("py") + +        # Note that _get_bad_ticks_message expects the first line to have two newlines. +        return ( +            "It looks like you pasted Python code without syntax highlighting.\n\n" +            "Please use syntax highlighting to improve the legibility of your code and make " +            "it easier for us to help you.\n\n" +            f"**To do this, use the following method:**\n{example_blocks}" +        ) +    else: +        log.trace("Aborting missing language instructions: content is not Python code.") + + +def get_instructions(content: str) -> Optional[str]: +    """ +    Parse `content` and return code block formatting instructions if something is wrong. + +    Return None if `content` lacks code block formatting issues. +    """ +    log.trace("Getting formatting instructions.") + +    blocks = _parsing.find_code_blocks(content) +    if blocks is None: +        log.trace("At least one valid code block found; no instructions to return.") +        return + +    if not blocks: +        log.trace("No code blocks were found in message.") +        instructions = _get_no_ticks_message(content) +    else: +        log.trace("Searching results for a code block with invalid ticks.") +        block = next((block for block in blocks if block.tick != _parsing.BACKTICK), None) + +        if block: +            log.trace("A code block exists but has invalid ticks.") +            instructions = _get_bad_ticks_message(block) +        else: +            log.trace("A code block exists but is missing a language.") +            block = blocks[0] + +            # Check for a bad language first to avoid parsing content into an AST. +            instructions = _get_bad_lang_message(block.content) +            if not instructions: +                instructions = _get_no_lang_message(block.content) + +    if instructions: +        instructions += "\nYou can **edit your original message** to correct your code block." + +    return instructions diff --git a/bot/exts/info/codeblock/_parsing.py b/bot/exts/info/codeblock/_parsing.py new file mode 100644 index 000000000..65a2272c8 --- /dev/null +++ b/bot/exts/info/codeblock/_parsing.py @@ -0,0 +1,228 @@ +"""This module provides functions for parsing Markdown code blocks.""" + +import ast +import logging +import re +import textwrap +from typing import NamedTuple, Optional, Sequence + +from bot import constants +from bot.utils import has_lines + +log = logging.getLogger(__name__) + +BACKTICK = "`" +PY_LANG_CODES = ("python-repl", "python", "pycon", "py")  # Order is important; "py" is last cause it's a subset. +_TICKS = { +    BACKTICK, +    "'", +    '"', +    "\u00b4",  # ACUTE ACCENT +    "\u2018",  # LEFT SINGLE QUOTATION MARK +    "\u2019",  # RIGHT SINGLE QUOTATION MARK +    "\u2032",  # PRIME +    "\u201c",  # LEFT DOUBLE QUOTATION MARK +    "\u201d",  # RIGHT DOUBLE QUOTATION MARK +    "\u2033",  # DOUBLE PRIME +    "\u3003",  # VERTICAL KANA REPEAT MARK UPPER HALF +} + +_RE_PYTHON_REPL = re.compile(r"^(>>>|\.\.\.)( |$)") +_RE_IPYTHON_REPL = re.compile(r"^((In|Out) \[\d+\]: |\s*\.{3,}: ?)") + +_RE_CODE_BLOCK = re.compile( +    fr""" +    (?P<ticks> +        (?P<tick>[{''.join(_TICKS)}]) # Put all ticks into a character class within a group. +        \2{{2}}                       # Match previous group 2 more times to ensure the same char. +    ) +    (?P<lang>[^\W_]+\n)?              # Optionally match a language specifier followed by a newline. +    (?P<code>.+?)                     # Match the actual code within the block. +    \1                                # Match the same 3 ticks used at the start of the block. +    """, +    re.DOTALL | re.VERBOSE +) + +_RE_LANGUAGE = re.compile( +    fr""" +    ^(?P<spaces>\s+)?                    # Optionally match leading spaces from the beginning. +    (?P<lang>{'|'.join(PY_LANG_CODES)})  # Match a Python language. +    (?P<newline>\n)?                     # Optionally match a newline following the language. +    """, +    re.IGNORECASE | re.VERBOSE +) + + +class CodeBlock(NamedTuple): +    """Represents a Markdown code block.""" + +    content: str +    language: str +    tick: str + + +class BadLanguage(NamedTuple): +    """Parsed information about a poorly formatted language specifier.""" + +    language: str +    has_leading_spaces: bool +    has_terminal_newline: bool + + +def find_code_blocks(message: str) -> Optional[Sequence[CodeBlock]]: +    """ +    Find and return all Markdown code blocks in the `message`. + +    Code blocks with 3 or fewer lines are excluded. + +    If the `message` contains at least one code block with valid ticks and a specified language, +    return None. This is based on the assumption that if the user managed to get one code block +    right, they already know how to fix the rest themselves. +    """ +    log.trace("Finding all code blocks in a message.") + +    code_blocks = [] +    for match in _RE_CODE_BLOCK.finditer(message): +        # Used to ensure non-matched groups have an empty string as the default value. +        groups = match.groupdict("") +        language = groups["lang"].strip()  # Strip the newline cause it's included in the group. + +        if groups["tick"] == BACKTICK and language: +            log.trace("Message has a valid code block with a language; returning None.") +            return None +        elif has_lines(groups["code"], constants.CodeBlock.minimum_lines): +            code_block = CodeBlock(groups["code"], language, groups["tick"]) +            code_blocks.append(code_block) +        else: +            log.trace("Skipped a code block shorter than 4 lines.") + +    return code_blocks + + +def _is_python_code(content: str) -> bool: +    """Return True if `content` is valid Python consisting of more than just expressions.""" +    log.trace("Checking if content is Python code.") +    try: +        # Attempt to parse the message into an AST node. +        # Invalid Python code will raise a SyntaxError. +        tree = ast.parse(content) +    except SyntaxError: +        log.trace("Code is not valid Python.") +        return False + +    # Multiple lines of single words could be interpreted as expressions. +    # This check is to avoid all nodes being parsed as expressions. +    # (e.g. words over multiple lines) +    if not all(isinstance(node, ast.Expr) for node in tree.body): +        log.trace("Code is valid python.") +        return True +    else: +        log.trace("Code consists only of expressions.") +        return False + + +def _is_repl_code(content: str, threshold: int = 3) -> bool: +    """Return True if `content` has at least `threshold` number of (I)Python REPL-like lines.""" +    log.trace(f"Checking if content is (I)Python REPL code using a threshold of {threshold}.") + +    repl_lines = 0 +    patterns = (_RE_PYTHON_REPL, _RE_IPYTHON_REPL) + +    for line in content.splitlines(): +        # Check the line against all patterns. +        for pattern in patterns: +            if pattern.match(line): +                repl_lines += 1 + +                # Once a pattern is matched, only use that pattern for the remaining lines. +                patterns = (pattern,) +                break + +        if repl_lines == threshold: +            log.trace("Content is (I)Python REPL code.") +            return True + +    log.trace("Content is not (I)Python REPL code.") +    return False + + +def is_python_code(content: str) -> bool: +    """Return True if `content` is valid Python code or (I)Python REPL output.""" +    dedented = textwrap.dedent(content) + +    # Parse AST twice in case _fix_indentation ends up breaking code due to its inaccuracies. +    return ( +        _is_python_code(dedented) +        or _is_repl_code(dedented) +        or _is_python_code(_fix_indentation(content)) +    ) + + +def parse_bad_language(content: str) -> Optional[BadLanguage]: +    """ +    Return information about a poorly formatted Python language in code block `content`. + +    If the language is not Python, return None. +    """ +    log.trace("Parsing bad language.") + +    match = _RE_LANGUAGE.match(content) +    if not match: +        return None + +    return BadLanguage( +        language=match["lang"], +        has_leading_spaces=match["spaces"] is not None, +        has_terminal_newline=match["newline"] is not None, +    ) + + +def _get_leading_spaces(content: str) -> int: +    """Return the number of spaces at the start of the first line in `content`.""" +    leading_spaces = 0 +    for char in content: +        if char == " ": +            leading_spaces += 1 +        else: +            return leading_spaces + + +def _fix_indentation(content: str) -> str: +    """ +    Attempt to fix badly indented code in `content`. + +    In most cases, this works like textwrap.dedent. However, if the first line ends with a colon, +    all subsequent lines are re-indented to only be one level deep relative to the first line. +    The intent is to fix cases where the leading spaces of the first line of code were accidentally +    not copied, which makes the first line appear not indented. + +    This is fairly naïve and inaccurate. Therefore, it may break some code that was otherwise valid. +    It's meant to catch really common cases, so that's acceptable. Its flaws are: + +    - It assumes that if the first line ends with a colon, it is the start of an indented block +    - It uses 4 spaces as the indentation, regardless of what the rest of the code uses +    """ +    lines = content.splitlines(keepends=True) + +    # Dedent the first line +    first_indent = _get_leading_spaces(content) +    first_line = lines[0][first_indent:] + +    # Can't assume there'll be multiple lines cause line counts of edited messages aren't checked. +    if len(lines) == 1: +        return first_line + +    second_indent = _get_leading_spaces(lines[1]) + +    # If the first line ends with a colon, all successive lines need to be indented one +    # additional level (assumes an indent width of 4). +    if first_line.rstrip().endswith(":"): +        second_indent -= 4 + +    # All lines must be dedented at least by the same amount as the first line. +    first_indent = max(first_indent, second_indent) + +    # Dedent the rest of the lines and join them together with the first line. +    content = first_line + "".join(line[first_indent:] for line in lines[1:]) + +    return content diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py index c16a99225..7ec8caa4b 100644 --- a/bot/exts/info/doc.py +++ b/bot/exts/info/doc.py @@ -3,10 +3,9 @@ import functools  import logging  import re  import textwrap -from collections import OrderedDict  from contextlib import suppress  from types import SimpleNamespace -from typing import Any, Callable, Optional, Tuple +from typing import Optional, Tuple  import discord  from bs4 import BeautifulSoup @@ -22,6 +21,7 @@ from bot.bot import Bot  from bot.constants import MODERATION_ROLES, RedirectOutput  from bot.converters import ValidPythonIdentifier, ValidURL  from bot.pagination import LinePaginator +from bot.utils.cache import AsyncCache  from bot.utils.messages import wait_for_deletion @@ -65,34 +65,7 @@ WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)")  FAILED_REQUEST_RETRY_AMOUNT = 3  NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay - -def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: -    """ -    LRU cache implementation for coroutines. - -    Once the cache exceeds the maximum size, keys are deleted in FIFO order. - -    An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. -    """ -    # Assign the cache to the function itself so we can clear it from outside. -    async_cache.cache = OrderedDict() - -    def decorator(function: Callable) -> Callable: -        """Define the async_cache decorator.""" -        @functools.wraps(function) -        async def wrapper(*args) -> Any: -            """Decorator wrapper for the caching logic.""" -            key = ':'.join(args[arg_offset:]) - -            value = async_cache.cache.get(key) -            if value is None: -                if len(async_cache.cache) > max_size: -                    async_cache.cache.popitem(last=False) - -                async_cache.cache[key] = await function(*args) -            return async_cache.cache[key] -        return wrapper -    return decorator +symbol_cache = AsyncCache()  class DocMarkdownConverter(MarkdownConverter): @@ -215,7 +188,7 @@ class Doc(commands.Cog):          self.base_urls.clear()          self.inventories.clear()          self.renamed_symbols.clear() -        async_cache.cache = OrderedDict() +        symbol_cache.clear()          # Run all coroutines concurrently - since each of them performs a HTTP          # request, this speeds up fetching the inventory data heavily. @@ -280,7 +253,7 @@ class Doc(commands.Cog):          return signatures, description.replace('¶', '') -    @async_cache(arg_offset=1) +    @symbol_cache(arg_offset=1)      async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]:          """          Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 0f50138e7..5aaf85e5a 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -6,14 +6,16 @@ from collections import Counter, defaultdict  from string import Template  from typing import Any, Mapping, Optional, Tuple, Union -from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils +from discord import ChannelType, Colour, Embed, Guild, Message, Role, Status, utils  from discord.abc import GuildChannel  from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role  from bot import constants  from bot.bot import Bot +from bot.converters import FetchedMember  from bot.decorators import in_whitelist  from bot.pagination import LinePaginator +from bot.utils.channel import is_mod_channel  from bot.utils.checks import cooldown_with_role_bypass, has_no_roles_check, in_whitelist_check  from bot.utils.time import time_since @@ -191,7 +193,7 @@ class Information(Cog):          await ctx.send(embed=embed)      @command(name="user", aliases=["user_info", "member", "member_info"]) -    async def user_info(self, ctx: Context, user: Member = None) -> None: +    async def user_info(self, ctx: Context, user: FetchedMember = None) -> None:          """Returns info about a user."""          if user is None:              user = ctx.author @@ -206,12 +208,14 @@ class Information(Cog):              embed = await self.create_user_embed(ctx, user)              await ctx.send(embed=embed) -    async def create_user_embed(self, ctx: Context, user: Member) -> Embed: +    async def create_user_embed(self, ctx: Context, user: FetchedMember) -> Embed:          """Creates an embed containing information on the `user`.""" +        on_server = bool(ctx.guild.get_member(user.id)) +          created = time_since(user.created_at, max_units=3)          name = str(user) -        if user.nick: +        if on_server and user.nick:              name = f"{user.nick} ({name})"          badges = [] @@ -220,8 +224,16 @@ class Information(Cog):              if is_set and (emoji := getattr(constants.Emojis, f"badge_{badge}", None)):                  badges.append(emoji) -        joined = time_since(user.joined_at, max_units=3) -        roles = ", ".join(role.mention for role in user.roles[1:]) +        if on_server: +            joined = time_since(user.joined_at, max_units=3) +            roles = ", ".join(role.mention for role in user.roles[1:]) +            membership = textwrap.dedent(f""" +                             Joined: {joined} +                             Roles: {roles or None} +                         """).strip() +        else: +            roles = None +            membership = "The user is not a member of the server"          fields = [              ( @@ -234,21 +246,12 @@ class Information(Cog):              ),              (                  "Member information", -                textwrap.dedent(f""" -                    Joined: {joined} -                    Roles: {roles or None} -                """).strip() +                membership              ),          ] -        # Use getattr to future-proof for commands invoked via DMs. -        show_verbose = ( -            ctx.channel.id in constants.MODERATION_CHANNELS -            or getattr(ctx.channel, "category_id", None) == constants.Categories.modmail -        ) -          # Show more verbose output in moderation channels for infractions and nominations -        if show_verbose: +        if is_mod_channel(ctx.channel):              fields.append(await self.expanded_user_infraction_counts(user))              fields.append(await self.user_nomination_counts(user))          else: @@ -268,13 +271,13 @@ class Information(Cog):          return embed -    async def basic_user_infraction_counts(self, member: Member) -> Tuple[str, str]: +    async def basic_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:          """Gets the total and active infraction counts for the given `member`."""          infractions = await self.bot.api_client.get(              'bot/infractions',              params={                  'hidden': 'False', -                'user__id': str(member.id) +                'user__id': str(user.id)              }          ) @@ -285,7 +288,7 @@ class Information(Cog):          return "Infractions", infraction_output -    async def expanded_user_infraction_counts(self, member: Member) -> Tuple[str, str]: +    async def expanded_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:          """          Gets expanded infraction counts for the given `member`. @@ -295,7 +298,7 @@ class Information(Cog):          infractions = await self.bot.api_client.get(              'bot/infractions',              params={ -                'user__id': str(member.id) +                'user__id': str(user.id)              }          ) @@ -326,12 +329,12 @@ class Information(Cog):          return "Infractions", "\n".join(infraction_output) -    async def user_nomination_counts(self, member: Member) -> Tuple[str, str]: +    async def user_nomination_counts(self, user: FetchedMember) -> Tuple[str, str]:          """Gets the active and historical nomination counts for the given `member`."""          nominations = await self.bot.api_client.get(              'bot/nominations',              params={ -                'user__id': str(member.id) +                'user__id': str(user.id)              }          ) diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py index 9b9c0028f..6790be762 100644 --- a/bot/exts/info/reddit.py +++ b/bot/exts/info/reddit.py @@ -140,7 +140,10 @@ class Reddit(Cog):                  # Got appropriate response - process and return.                  content = await response.json()                  posts = content["data"]["children"] -                return posts[:amount] + +                filtered_posts = [post for post in posts if not post["data"]["over_18"]] + +                return filtered_posts[:amount]              await asyncio.sleep(3) @@ -163,12 +166,11 @@ class Reddit(Cog):              amount=amount,              params={"t": time}          ) -          if not posts:              embed.title = random.choice(ERROR_REPLIES)              embed.colour = Colour.red()              embed.description = ( -                "Sorry! We couldn't find any posts from that subreddit. " +                "Sorry! We couldn't find any SFW posts from that subreddit. "                  "If this problem persists, please let us know."              ) diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py index 21aa91873..4d8bb645e 100644 --- a/bot/exts/info/stats.py +++ b/bot/exts/info/stats.py @@ -6,7 +6,7 @@ from discord.ext.tasks import loop  from bot.bot import Bot  from bot.constants import Categories, Channels, Guild - +from bot.utils.channel import is_in_category  CHANNEL_NAME_OVERRIDES = {      Channels.off_topic_0: "off_topic_0", @@ -35,8 +35,7 @@ class Stats(Cog):          if message.guild.id != Guild.id:              return -        cat = getattr(message.channel, "category", None) -        if cat is not None and cat.id == Categories.modmail: +        if is_in_category(message.channel, Categories.modmail):              if message.channel.id != Channels.incidents:                  # Do not report modmail channels to stats, there are too many                  # of them for interesting statistics to be drawn out of this. diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py index 99bb1ae11..ed67e3b26 100644 --- a/bot/exts/moderation/infraction/_scheduler.py +++ b/bot/exts/moderation/infraction/_scheduler.py @@ -12,11 +12,12 @@ from discord.ext.commands import Context  from bot import constants  from bot.api import ResponseCodeError  from bot.bot import Bot -from bot.constants import Colours, MODERATION_CHANNELS +from bot.constants import Colours  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._utils import UserSnowflake  from bot.exts.moderation.modlog import ModLog  from bot.utils import messages, scheduling, time +from bot.utils.channel import is_mod_channel  log = logging.getLogger(__name__) @@ -130,7 +131,7 @@ class InfractionScheduler:                  log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})")              else:                  # Accordingly display whether the user was successfully notified via DM. -                if await _utils.notify_infraction(user, infr_type, expiry, reason, icon): +                if await _utils.notify_infraction(user, " ".join(infr_type.split("_")).title(), expiry, reason, icon):                      dm_result = ":incoming_envelope: "                      dm_log_text = "\nDM: Sent" @@ -141,11 +142,7 @@ class InfractionScheduler:              )              if reason:                  end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})" -        elif ctx.channel.id not in MODERATION_CHANNELS: -            log.trace( -                f"Infraction #{id_} context is not in a mod channel; omitting infraction count." -            ) -        else: +        elif is_mod_channel(ctx.channel):              log.trace(f"Fetching total infraction count for {user}.")              infractions = await self.bot.api_client.get( @@ -153,7 +150,7 @@ class InfractionScheduler:                  params={"user__id": str(user.id)}              )              total = len(infractions) -            end_msg = f" ({total} infraction{ngettext('', 's', total)} total)" +            end_msg = f" (#{id_} ; {total} infraction{ngettext('', 's', total)} total)"          # Execute the necessary actions to apply the infraction on Discord.          if action_coro: @@ -171,7 +168,7 @@ class InfractionScheduler:                  log_content = ctx.author.mention                  log_title = "failed to apply" -                log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" +                log_msg = f"Failed to apply {' '.join(infr_type.split('_'))} infraction #{id_} to {user}"                  if isinstance(e, discord.Forbidden):                      log.warning(f"{log_msg}: bot lacks permissions.")                  else: @@ -188,7 +185,7 @@ class InfractionScheduler:                  log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.")              infr_message = ""          else: -            infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}" +            infr_message = f" **{' '.join(infr_type.split('_'))}** to {user.mention}{expiry_msg}{end_msg}"          # Send a confirmation message to the invoking context.          log.trace(f"Sending infraction #{id_} confirmation message.") @@ -200,7 +197,7 @@ class InfractionScheduler:          await self.mod_log.send_log_message(              icon_url=icon,              colour=Colours.soft_red, -            title=f"Infraction {log_title}: {infr_type}", +            title=f"Infraction {log_title}: {' '.join(infr_type.split('_'))}",              thumbnail=user.avatar_url_as(static_format="png"),              text=textwrap.dedent(f"""                  Member: {messages.format_user(user)} @@ -277,7 +274,7 @@ class InfractionScheduler:          if send_msg:              log.trace(f"Sending infraction #{id_} pardon confirmation message.")              await ctx.send( -                f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " +                f"{dm_emoji}{confirm_msg} infraction **{' '.join(infr_type.split('_'))}** for {user.mention}. "                  f"{log_text.get('Failure', '')}"              ) @@ -288,7 +285,7 @@ class InfractionScheduler:          await self.mod_log.send_log_message(              icon_url=_utils.INFRACTION_ICONS[infr_type][1],              colour=Colours.soft_green, -            title=f"Infraction {log_title}: {infr_type}", +            title=f"Infraction {log_title}: {' '.join(infr_type.split('_'))}",              thumbnail=user.avatar_url_as(static_format="png"),              text="\n".join(f"{k}: {v}" for k, v in log_text.items()),              footer=footer, diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py index 1d91964f1..d0dc3f0a1 100644 --- a/bot/exts/moderation/infraction/_utils.py +++ b/bot/exts/moderation/infraction/_utils.py @@ -18,9 +18,10 @@ INFRACTION_ICONS = {      "note": (Icons.user_warn, None),      "superstar": (Icons.superstarify, Icons.unsuperstarify),      "warning": (Icons.user_warn, None), +    "voice_ban": (Icons.voice_state_red, Icons.voice_state_green),  }  RULES_URL = "https://pythondiscord.com/pages/rules" -APPEALABLE_INFRACTIONS = ("ban", "mute") +APPEALABLE_INFRACTIONS = ("ban", "mute", "voice_ban")  # Type aliases  UserObject = t.Union[discord.Member, discord.User] @@ -154,7 +155,7 @@ async def notify_infraction(      log.trace(f"Sending {user} a DM about their {infr_type} infraction.")      text = INFRACTION_DESCRIPTION_TEMPLATE.format( -        type=infr_type.capitalize(), +        type=infr_type.title(),          expires=expires_at or "N/A",          reason=reason or "No reason provided."      ) diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index b638f4dc6..8abb199db 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -31,6 +31,7 @@ class Infractions(InfractionScheduler, commands.Cog):          self.category = "Moderation"          self._muted_role = discord.Object(constants.Roles.muted) +        self._voice_verified_role = discord.Object(constants.Roles.voice_verified)      @commands.Cog.listener()      async def on_member_join(self, member: Member) -> None: @@ -88,6 +89,11 @@ class Infractions(InfractionScheduler, commands.Cog):          """          await self.apply_ban(ctx, user, reason, max(min(purge_days, 7), 0)) +    @command(aliases=('vban',)) +    async def voiceban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str]) -> None: +        """Permanently ban user from using voice channels.""" +        await self.apply_voice_ban(ctx, user, reason) +      # endregion      # region: Temporary infractions @@ -136,6 +142,32 @@ class Infractions(InfractionScheduler, commands.Cog):          """          await self.apply_ban(ctx, user, reason, expires_at=duration) +    @command(aliases=("tempvban", "tvban")) +    async def tempvoiceban( +            self, +            ctx: Context, +            user: FetchedMember, +            duration: Expiry, +            *, +            reason: t.Optional[str] +    ) -> None: +        """ +        Temporarily voice ban a user for the given reason and duration. + +        A unit of time should be appended to the duration. +        Units (∗case-sensitive): +        \u2003`y` - years +        \u2003`m` - months∗ +        \u2003`w` - weeks +        \u2003`d` - days +        \u2003`h` - hours +        \u2003`M` - minutes∗ +        \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration. +        """ +        await self.apply_voice_ban(ctx, user, reason, expires_at=duration) +      # endregion      # region: Permanent shadow infractions @@ -225,6 +257,11 @@ class Infractions(InfractionScheduler, commands.Cog):          """Prematurely end the active ban infraction for the user."""          await self.pardon_infraction(ctx, "ban", user) +    @command(aliases=("uvban",)) +    async def unvoiceban(self, ctx: Context, user: FetchedMember) -> None: +        """Prematurely end the active voice ban infraction for the user.""" +        await self.pardon_infraction(ctx, "voice_ban", user) +      # endregion      # region: Base apply functions @@ -327,6 +364,26 @@ class Infractions(InfractionScheduler, commands.Cog):          bb_reason = "User has been permanently banned from the server. Automatically removed."          await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False) +    @respect_role_hierarchy(member_arg=2) +    async def apply_voice_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None: +        """Apply a voice ban infraction with kwargs passed to `post_infraction`.""" +        if await _utils.get_active_infraction(ctx, user, "voice_ban"): +            return + +        infraction = await _utils.post_infraction(ctx, user, "voice_ban", reason, active=True, **kwargs) +        if infraction is None: +            return + +        self.mod_log.ignore(Event.member_update, user.id) + +        if reason: +            reason = textwrap.shorten(reason, width=512, placeholder="...") + +        await user.move_to(None, reason="Disconnected from voice to apply voiceban.") + +        action = user.remove_roles(self._voice_verified_role, reason=reason) +        await self.apply_infraction(ctx, infraction, user, action) +      # endregion      # region: Base pardon functions @@ -371,6 +428,27 @@ class Infractions(InfractionScheduler, commands.Cog):          return log_text +    async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]: +        """Add Voice Verified role back to user, DM them a notification, and return a log dict.""" +        user = guild.get_member(user_id) +        log_text = {} + +        if user: +            # DM user about infraction expiration +            notified = await _utils.notify_pardon( +                user=user, +                title="Voice ban ended", +                content="You have been unbanned and can verify yourself again in the server.", +                icon_url=_utils.INFRACTION_ICONS["voice_ban"][1] +            ) + +            log_text["Member"] = format_user(user) +            log_text["DM"] = "Sent" if notified else "**Failed**" +        else: +            log_text["Info"] = "User was not found in the guild." + +        return log_text +      async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]:          """          Execute deactivation steps specific to the infraction's type and return a log dict. @@ -385,6 +463,8 @@ class Infractions(InfractionScheduler, commands.Cog):              return await self.pardon_mute(user_id, guild, reason)          elif infraction["type"] == "ban":              return await self.pardon_ban(user_id, guild, reason) +        elif infraction["type"] == "voice_ban": +            return await self.pardon_voice_ban(user_id, guild, reason)      # endregion diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py index cdab1a6c7..394f63da3 100644 --- a/bot/exts/moderation/infraction/management.py +++ b/bot/exts/moderation/infraction/management.py @@ -15,7 +15,7 @@ from bot.exts.moderation.infraction.infractions import Infractions  from bot.exts.moderation.modlog import ModLog  from bot.pagination import LinePaginator  from bot.utils import messages, time -from bot.utils.checks import in_whitelist_check +from bot.utils.channel import is_mod_channel  log = logging.getLogger(__name__) @@ -295,13 +295,7 @@ class ModManagement(commands.Cog):          """Only allow moderators inside moderator channels to invoke the commands in this cog."""          checks = [              await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx), -            in_whitelist_check( -                ctx, -                channels=constants.MODERATION_CHANNELS, -                categories=[constants.Categories.modmail], -                redirect=None, -                fail_silently=True, -            ) +            is_mod_channel(ctx.channel)          ]          return all(checks) diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py index 229f991a0..e6712b3b6 100644 --- a/bot/exts/moderation/silence.py +++ b/bot/exts/moderation/silence.py @@ -1,8 +1,11 @@ -import asyncio +import json  import logging  from contextlib import suppress +from datetime import datetime, timedelta, timezone +from operator import attrgetter  from typing import Optional +from async_rediscache import RedisCache  from discord import TextChannel  from discord.ext import commands, tasks  from discord.ext.commands import Context @@ -10,10 +13,25 @@ from discord.ext.commands import Context  from bot.bot import Bot  from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles  from bot.converters import HushDurationConverter +from bot.utils.lock import LockedResourceError, lock_arg  from bot.utils.scheduling import Scheduler  log = logging.getLogger(__name__) +LOCK_NAMESPACE = "silence" + +MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced." +MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely." +MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)." + +MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced." +MSG_UNSILENCE_MANUAL = ( +    f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were " +    f"set manually or the cache was prematurely cleared. " +    f"Please edit the overwrites manually to unsilence." +) +MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel." +  class SilenceNotifier(tasks.Loop):      """Loop notifier for posting notices to `alert_channel` containing added channels.""" @@ -56,25 +74,32 @@ class SilenceNotifier(tasks.Loop):  class Silence(commands.Cog):      """Commands for stopping channel messages for `verified` role in a channel.""" +    # Maps muted channel IDs to their previous overwrites for send_message and add_reactions. +    # Overwrites are stored as JSON. +    previous_overwrites = RedisCache() + +    # Maps muted channel IDs to POSIX timestamps of when they'll be unsilenced. +    # A timestamp equal to -1 means it's indefinite. +    unsilence_timestamps = RedisCache() +      def __init__(self, bot: Bot):          self.bot = bot          self.scheduler = Scheduler(self.__class__.__name__) -        self.muted_channels = set() -        self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars()) -        self._get_instance_vars_event = asyncio.Event() +        self._init_task = self.bot.loop.create_task(self._async_init()) -    async def _get_instance_vars(self) -> None: -        """Get instance variables after they're available to get from the guild.""" +    async def _async_init(self) -> None: +        """Set instance attributes once the guild is available and reschedule unsilences."""          await self.bot.wait_until_guild_available() +          guild = self.bot.get_guild(Guild.id)          self._verified_role = guild.get_role(Roles.verified)          self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts) -        self._mod_log_channel = self.bot.get_channel(Channels.mod_log) -        self.notifier = SilenceNotifier(self._mod_log_channel) -        self._get_instance_vars_event.set() +        self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log)) +        await self._reschedule()      @commands.command(aliases=("hush",)) +    @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True)      async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None:          """          Silence the current channel for `duration` minutes or `forever`. @@ -82,18 +107,25 @@ class Silence(commands.Cog):          Duration is capped at 15 minutes, passing forever makes the silence indefinite.          Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start.          """ -        await self._get_instance_vars_event.wait() -        log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.") -        if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration): -            await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.") -            return -        if duration is None: -            await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.") +        await self._init_task + +        channel_info = f"#{ctx.channel} ({ctx.channel.id})" +        log.debug(f"{ctx.author} is silencing channel {channel_info}.") + +        if not await self._set_silence_overwrites(ctx.channel): +            log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.") +            await ctx.send(MSG_SILENCE_FAIL)              return -        await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).") +        await self._schedule_unsilence(ctx, duration) -        self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) +        if duration is None: +            self.notifier.add_channel(ctx.channel) +            log.info(f"Silenced {channel_info} indefinitely.") +            await ctx.send(MSG_SILENCE_PERMANENT) +        else: +            log.info(f"Silenced {channel_info} for {duration} minute(s).") +            await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration))      @commands.command(aliases=("unhush",))      async def unsilence(self, ctx: Context) -> None: @@ -102,61 +134,115 @@ class Silence(commands.Cog):          If the channel was silenced indefinitely, notifications for the channel will stop.          """ -        await self._get_instance_vars_event.wait() +        await self._init_task          log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.") -        if not await self._unsilence(ctx.channel): -            await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.") +        await self._unsilence_wrapper(ctx.channel) + +    @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True) +    async def _unsilence_wrapper(self, channel: TextChannel) -> None: +        """Unsilence `channel` and send a success/failure message.""" +        if not await self._unsilence(channel): +            overwrite = channel.overwrites_for(self._verified_role) +            if overwrite.send_messages is False or overwrite.add_reactions is False: +                await channel.send(MSG_UNSILENCE_MANUAL) +            else: +                await channel.send(MSG_UNSILENCE_FAIL)          else: -            await ctx.send(f"{Emojis.check_mark} unsilenced current channel.") +            await channel.send(MSG_UNSILENCE_SUCCESS) -    async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool: -        """ -        Silence `channel` for `self._verified_role`. +    async def _set_silence_overwrites(self, channel: TextChannel) -> bool: +        """Set silence permission overwrites for `channel` and return True if successful.""" +        overwrite = channel.overwrites_for(self._verified_role) +        prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions) -        If `persistent` is `True` add `channel` to notifier. -        `duration` is only used for logging; if None is passed `persistent` should be True to not log None. -        Return `True` if channel permissions were changed, `False` otherwise. -        """ -        current_overwrite = channel.overwrites_for(self._verified_role) -        if current_overwrite.send_messages is False: -            log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.") +        if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()):              return False -        await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False)) -        self.muted_channels.add(channel) -        if persistent: -            log.info(f"Silenced #{channel} ({channel.id}) indefinitely.") -            self.notifier.add_channel(channel) -            return True - -        log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).") + +        overwrite.update(send_messages=False, add_reactions=False) +        await channel.set_permissions(self._verified_role, overwrite=overwrite) +        await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites)) +          return True +    async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None: +        """Schedule `ctx.channel` to be unsilenced if `duration` is not None.""" +        if duration is None: +            await self.unsilence_timestamps.set(ctx.channel.id, -1) +        else: +            self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence)) +            unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration) +            await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp()) +      async def _unsilence(self, channel: TextChannel) -> bool:          """          Unsilence `channel`. -        Check if `channel` is silenced through a `PermissionOverwrite`, -        if it is unsilence it and remove it from the notifier. +        If `channel` has a silence task scheduled or has its previous overwrites cached, unsilence +        it, cancel the task, and remove it from the notifier. Notify admins if it has a task but +        not cached overwrites. +          Return `True` if channel permissions were changed, `False` otherwise.          """ -        current_overwrite = channel.overwrites_for(self._verified_role) -        if current_overwrite.send_messages is False: -            await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None)) -            log.info(f"Unsilenced channel #{channel} ({channel.id}).") -            self.scheduler.cancel(channel.id) -            self.notifier.remove_channel(channel) -            self.muted_channels.discard(channel) -            return True -        log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") -        return False +        prev_overwrites = await self.previous_overwrites.get(channel.id) +        if channel.id not in self.scheduler and prev_overwrites is None: +            log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.") +            return False + +        overwrite = channel.overwrites_for(self._verified_role) +        if prev_overwrites is None: +            log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.") +            overwrite.update(send_messages=None, add_reactions=None) +        else: +            overwrite.update(**json.loads(prev_overwrites)) + +        await channel.set_permissions(self._verified_role, overwrite=overwrite) +        log.info(f"Unsilenced channel #{channel} ({channel.id}).") + +        self.scheduler.cancel(channel.id) +        self.notifier.remove_channel(channel) +        await self.previous_overwrites.delete(channel.id) +        await self.unsilence_timestamps.delete(channel.id) + +        if prev_overwrites is None: +            await self._mod_alerts_channel.send( +                f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing " +                f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` " +                f"overwrites for {self._verified_role.mention} are at their desired values." +            ) + +        return True + +    async def _reschedule(self) -> None: +        """Reschedule unsilencing of active silences and add permanent ones to the notifier.""" +        for channel_id, timestamp in await self.unsilence_timestamps.items(): +            channel = self.bot.get_channel(channel_id) +            if channel is None: +                log.info(f"Can't reschedule silence for {channel_id}: channel not found.") +                continue + +            if timestamp == -1: +                log.info(f"Adding permanent silence for #{channel} ({channel.id}) to the notifier.") +                self.notifier.add_channel(channel) +                continue + +            dt = datetime.fromtimestamp(timestamp, tz=timezone.utc) +            delta = (dt - datetime.now(tz=timezone.utc)).total_seconds() +            if delta <= 0: +                # Suppress the error since it's not being invoked by a user via the command. +                with suppress(LockedResourceError): +                    await self._unsilence_wrapper(channel) +            else: +                log.info(f"Rescheduling silence for #{channel} ({channel.id}).") +                self.scheduler.schedule_later(delta, channel_id, self._unsilence_wrapper(channel))      def cog_unload(self) -> None: -        """Send alert with silenced channels and cancel scheduled tasks on unload.""" -        self.scheduler.cancel_all() -        if self.muted_channels: -            channels_string = ''.join(channel.mention for channel in self.muted_channels) -            message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}" -            self.bot.closing_tasks.append(asyncio.create_task(self._mod_alerts_channel.send(message))) +        """Cancel the init task and scheduled tasks.""" +        # It's important to wait for _init_task (specifically for _reschedule) to be cancelled +        # before cancelling scheduled tasks. Otherwise, it's possible for _reschedule to schedule +        # more tasks after cancel_all has finished, despite _init_task.cancel being called first. +        # This is cause cancel() on its own doesn't block until the task is cancelled. +        self._init_task.cancel() +        self._init_task.add_done_callback(lambda _: self.scheduler.cancel_all())      # This cannot be static (must have a __func__ attribute).      async def cog_check(self, ctx: Context) -> bool: diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py index c3ad8687e..c599156d0 100644 --- a/bot/exts/moderation/verification.py +++ b/bot/exts/moderation/verification.py @@ -11,6 +11,7 @@ from discord.ext.commands import Cog, Context, command, group, has_any_role  from discord.utils import snowflake_time  from bot import constants +from bot.api import ResponseCodeError  from bot.bot import Bot  from bot.decorators import has_no_roles, in_whitelist  from bot.exts.moderation.modlog import ModLog @@ -355,6 +356,28 @@ class Verification(Cog):          return n_success +    async def _add_kick_note(self, member: discord.Member) -> None: +        """ +        Post a note regarding `member` being kicked to site. + +        Allows keeping track of kicked members for auditing purposes. +        """ +        payload = { +            "active": False, +            "actor": self.bot.user.id,  # Bot actions this autonomously +            "expires_at": None, +            "hidden": True, +            "reason": "Verification kick", +            "type": "note", +            "user": member.id, +        } + +        log.trace(f"Posting kick note for member {member} ({member.id})") +        try: +            await self.bot.api_client.post("bot/infractions", json=payload) +        except ResponseCodeError as api_exc: +            log.warning("Failed to post kick note", exc_info=api_exc) +      async def _kick_members(self, members: t.Collection[discord.Member]) -> int:          """          Kick `members` from the PyDis guild. @@ -373,6 +396,7 @@ class Verification(Cog):              except discord.HTTPException as suspicious_exception:                  raise StopExecution(reason=suspicious_exception)              await member.kick(reason=f"User has not verified in {constants.Verification.kicked_after} days") +            await self._add_kick_note(member)          n_kicked = await self._send_requests(members, kick_request, Limit(batch_size=2, sleep_secs=1))          self.bot.stats.incr("verification.kicked", count=n_kicked) @@ -547,6 +571,16 @@ class Verification(Cog):          # video.          if raw_member.get("is_pending"):              await self.member_gating_cache.set(member.id, True) + +            # TODO: Temporary, remove soon after asking joe. +            await self.mod_log.send_log_message( +                icon_url=self.bot.user.avatar_url, +                colour=discord.Colour.blurple(), +                title="New native gated user", +                channel_id=constants.Channels.user_log, +                text=f"<@{member.id}> ({member.id})", +            ) +              return          log.trace(f"Sending on join message to new member: {member.id}") diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py new file mode 100644 index 000000000..93d96693c --- /dev/null +++ b/bot/exts/moderation/voice_gate.py @@ -0,0 +1,171 @@ +import asyncio +import logging +from contextlib import suppress +from datetime import datetime, timedelta + +import discord +from dateutil import parser +from discord import Colour +from discord.ext.commands import Cog, Context, command + +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.constants import Channels, Event, MODERATION_ROLES, Roles, VoiceGate as GateConf +from bot.decorators import has_no_roles, in_whitelist +from bot.exts.moderation.modlog import ModLog +from bot.utils.checks import InWhitelistCheckFailure + +log = logging.getLogger(__name__) + +FAILED_MESSAGE = ( +    """You are not currently eligible to use voice inside Python Discord for the following reasons:\n\n{reasons}""" +) + +MESSAGE_FIELD_MAP = { +    "verified_at": f"have been verified for less than {GateConf.minimum_days_verified} days", +    "voice_banned": "have an active voice ban infraction", +    "total_messages": f"have sent less than {GateConf.minimum_messages} messages", +    "activity_blocks": f"have been active for fewer than {GateConf.minimum_activity_blocks} ten-minute blocks", +} + + +class VoiceGate(Cog): +    """Voice channels verification management.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @property +    def mod_log(self) -> ModLog: +        """Get the currently loaded ModLog cog instance.""" +        return self.bot.get_cog("ModLog") + +    @command(aliases=('voiceverify',)) +    @has_no_roles(Roles.voice_verified) +    @in_whitelist(channels=(Channels.voice_gate,), redirect=None) +    async def voice_verify(self, ctx: Context, *_) -> None: +        """ +        Apply to be able to use voice within the Discord server. + +        In order to use voice you must meet all three of the following criteria: +        - You must have over a certain number of messages within the Discord server +        - You must have accepted our rules over a certain number of days ago +        - You must not be actively banned from using our voice channels +        - You must have been active for over a certain number of 10-minute blocks +        """ +        try: +            data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data") +        except ResponseCodeError as e: +            if e.status == 404: +                embed = discord.Embed( +                    title="Not found", +                    description=( +                        "We were unable to find user data for you. " +                        "Please try again shortly, " +                        "if this problem persists please contact the server staff through Modmail." +                    ), +                    color=Colour.red() +                ) +                log.info(f"Unable to find Metricity data about {ctx.author} ({ctx.author.id})") +            else: +                embed = discord.Embed( +                    title="Unexpected response", +                    description=( +                        "We encountered an error while attempting to find data for your user. " +                        "Please try again and let us know if the problem persists." +                    ), +                    color=Colour.red() +                ) +                log.warning(f"Got response code {e.status} while trying to get {ctx.author.id} Metricity data.") + +            await ctx.author.send(embed=embed) +            return + +        # Pre-parse this for better code style +        if data["verified_at"] is not None: +            data["verified_at"] = parser.isoparse(data["verified_at"]) +        else: +            data["verified_at"] = datetime.utcnow() - timedelta(days=3) + +        checks = { +            "verified_at": data["verified_at"] > datetime.utcnow() - timedelta(days=GateConf.minimum_days_verified), +            "total_messages": data["total_messages"] < GateConf.minimum_messages, +            "voice_banned": data["voice_banned"], +            "activity_blocks": data["activity_blocks"] < GateConf.minimum_activity_blocks +        } +        failed = any(checks.values()) +        failed_reasons = [MESSAGE_FIELD_MAP[key] for key, value in checks.items() if value is True] +        [self.bot.stats.incr(f"voice_gate.failed.{key}") for key, value in checks.items() if value is True] + +        if failed: +            embed = discord.Embed( +                title="Voice Gate failed", +                description=FAILED_MESSAGE.format(reasons="\n".join(f'• You {reason}.' for reason in failed_reasons)), +                color=Colour.red() +            ) +            try: +                await ctx.author.send(embed=embed) +                await ctx.send(f"{ctx.author}, please check your DMs.") +            except discord.Forbidden: +                await ctx.channel.send(ctx.author.mention, embed=embed) +            return + +        self.mod_log.ignore(Event.member_update, ctx.author.id) +        embed = discord.Embed( +            title="Voice gate passed", +            description="You have been granted permission to use voice channels in Python Discord.", +            color=Colour.green() +        ) + +        if ctx.author.voice: +            embed.description += "\n\nPlease reconnect to your voice channel to be granted your new permissions." + +        try: +            await ctx.author.send(embed=embed) +            await ctx.send(f"{ctx.author}, please check your DMs.") +        except discord.Forbidden: +            await ctx.channel.send(ctx.author.mention, embed=embed) + +        # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it. +        await asyncio.sleep(3) +        await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed") + +        self.bot.stats.incr("voice_gate.passed") + +    @Cog.listener() +    async def on_message(self, message: discord.Message) -> None: +        """Delete all non-staff messages from voice gate channel that don't invoke voice verify command.""" +        # Check is channel voice gate +        if message.channel.id != Channels.voice_gate: +            return + +        ctx = await self.bot.get_context(message) +        is_verify_command = ctx.command is not None and ctx.command.name == "voice_verify" + +        # When it's bot sent message, delete it after some time +        if message.author.bot: +            with suppress(discord.NotFound): +                await message.delete(delay=GateConf.bot_message_delete_delay) +                return + +        # Then check is member moderator+, because we don't want to delete their messages. +        if any(role.id in MODERATION_ROLES for role in message.author.roles) and is_verify_command is False: +            log.trace(f"Excluding moderator message {message.id} from deletion in #{message.channel}.") +            return + +        # Ignore deleted voice verification messages +        if ctx.command is not None and ctx.command.name == "voice_verify": +            self.mod_log.ignore(Event.message_delete, message.id) + +        with suppress(discord.NotFound): +            await message.delete() + +    async def cog_command_error(self, ctx: Context, error: Exception) -> None: +        """Check for & ignore any InWhitelistCheckFailure.""" +        if isinstance(error, InWhitelistCheckFailure): +            error.handled = True + + +def setup(bot: Bot) -> None: +    """Loads the VoiceGate cog.""" +    bot.add_cog(VoiceGate(bot)) diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py index ba1fd2a5c..69d623581 100644 --- a/bot/exts/utils/bot.py +++ b/bot/exts/utils/bot.py @@ -1,22 +1,14 @@ -import ast  import logging -import re -import time -from typing import Optional, Tuple +from typing import Optional -from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord import Embed, TextChannel  from discord.ext.commands import Cog, Context, command, group, has_any_role  from bot.bot import Bot -from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs -from bot.exts.filters.token_remover import TokenRemover -from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE -from bot.utils.messages import wait_for_deletion +from bot.constants import Guild, MODERATION_ROLES, Roles, URLs  log = logging.getLogger(__name__) -RE_MARKDOWN = re.compile(r'([*_~`|>])') -  class BotCog(Cog, name="Bot"):      """Bot information commands.""" @@ -24,19 +16,6 @@ class BotCog(Cog, name="Bot"):      def __init__(self, bot: Bot):          self.bot = bot -        # Stores allowed channels plus epoch time since last call. -        self.channel_cooldowns = { -            Channels.python_discussion: 0, -        } - -        # These channels will also work, but will not be subject to cooldown -        self.channel_whitelist = ( -            Channels.bot_commands, -        ) - -        # Stores improperly formatted Python codeblock message ids and the corresponding bot message -        self.codeblock_message_ids = {} -      @group(invoke_without_command=True, name="bot", hidden=True)      @has_any_role(Roles.verified)      async def botinfo_group(self, ctx: Context) -> None: @@ -81,305 +60,6 @@ class BotCog(Cog, name="Bot"):          else:              await channel.send(embed=embed) -    def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]: -        """ -        Strip msg in order to find Python code. - -        Tries to strip out Python code out of msg and returns the stripped block or -        None if the block is a valid Python codeblock. -        """ -        if msg.count("\n") >= 3: -            # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found. -            if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks: -                log.trace( -                    "Someone wrote a message that was already a " -                    "valid Python syntax highlighted code block. No action taken." -                ) -                return None - -            else: -                # Stripping backticks from every line of the message. -                log.trace(f"Stripping backticks from message.\n\n{msg}\n\n") -                content = "" -                for line in msg.splitlines(keepends=True): -                    content += line.strip("`") - -                content = content.strip() - -                # Remove "Python" or "Py" from start of the message if it exists. -                log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n") -                pycode = False -                if content.lower().startswith("python"): -                    content = content[6:] -                    pycode = True -                elif content.lower().startswith("py"): -                    content = content[2:] -                    pycode = True - -                if pycode: -                    content = content.splitlines(keepends=True) - -                    # Check if there might be code in the first line, and preserve it. -                    first_line = content[0] -                    if " " in content[0]: -                        first_space = first_line.index(" ") -                        content[0] = first_line[first_space:] -                        content = "".join(content) - -                    # If there's no code we can just get rid of the first line. -                    else: -                        content = "".join(content[1:]) - -                # Strip it again to remove any leading whitespace. This is necessary -                # if the first line of the message looked like ```python <code> -                old = content.strip() - -                # Strips REPL code out of the message if there is any. -                content, repl_code = self.repl_stripping(old) -                if old != content: -                    return (content, old), repl_code - -                # Try to apply indentation fixes to the code. -                content = self.fix_indentation(content) - -                # Check if the code contains backticks, if it does ignore the message. -                if "`" in content: -                    log.trace("Detected ` inside the code, won't reply") -                    return None -                else: -                    log.trace(f"Returning message.\n\n{content}\n\n") -                    return (content,), repl_code - -    def fix_indentation(self, msg: str) -> str: -        """Attempts to fix badly indented code.""" -        def unindent(code: str, skip_spaces: int = 0) -> str: -            """Unindents all code down to the number of spaces given in skip_spaces.""" -            final = "" -            current = code[0] -            leading_spaces = 0 - -            # Get numbers of spaces before code in the first line. -            while current == " ": -                current = code[leading_spaces + 1] -                leading_spaces += 1 -            leading_spaces -= skip_spaces - -            # If there are any, remove that number of spaces from every line. -            if leading_spaces > 0: -                for line in code.splitlines(keepends=True): -                    line = line[leading_spaces:] -                    final += line -                return final -            else: -                return code - -        # Apply fix for "all lines are overindented" case. -        msg = unindent(msg) - -        # If the first line does not end with a colon, we can be -        # certain the next line will be on the same indentation level. -        # -        # If it does end with a colon, we will need to indent all successive -        # lines one additional level. -        first_line = msg.splitlines()[0] -        code = "".join(msg.splitlines(keepends=True)[1:]) -        if not first_line.endswith(":"): -            msg = f"{first_line}\n{unindent(code)}" -        else: -            msg = f"{first_line}\n{unindent(code, 4)}" -        return msg - -    def repl_stripping(self, msg: str) -> Tuple[str, bool]: -        """ -        Strip msg in order to extract Python code out of REPL output. - -        Tries to strip out REPL Python code out of msg and returns the stripped msg. - -        Returns True for the boolean if REPL code was found in the input msg. -        """ -        final = "" -        for line in msg.splitlines(keepends=True): -            if line.startswith(">>>") or line.startswith("..."): -                final += line[4:] -        log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n") -        if not final: -            log.trace(f"Found no REPL code in \n\n{msg}\n\n") -            return msg, False -        else: -            log.trace(f"Found REPL code in \n\n{msg}\n\n") -            return final.rstrip(), True - -    def has_bad_ticks(self, msg: Message) -> bool: -        """Check to see if msg contains ticks that aren't '`'.""" -        not_backticks = [ -            "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019", -            "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033", -            "\u3003\u3003\u3003" -        ] - -        return msg.content[:3] in not_backticks - -    @Cog.listener() -    async def on_message(self, msg: Message) -> None: -        """ -        Detect poorly formatted Python code in new messages. - -        If poorly formatted code is detected, send the user a helpful message explaining how to do -        properly formatted Python syntax highlighting codeblocks. -        """ -        is_help_channel = ( -            getattr(msg.channel, "category", None) -            and msg.channel.category.id in (Categories.help_available, Categories.help_in_use) -        ) -        parse_codeblock = ( -            ( -                is_help_channel -                or msg.channel.id in self.channel_cooldowns -                or msg.channel.id in self.channel_whitelist -            ) -            and not msg.author.bot -            and len(msg.content.splitlines()) > 3 -            and not TokenRemover.find_token_in_message(msg) -            and not WEBHOOK_URL_RE.search(msg.content) -        ) - -        if parse_codeblock:  # no token in the msg -            on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 -            if not on_cooldown or DEBUG_MODE: -                try: -                    if self.has_bad_ticks(msg): -                        ticks = msg.content[:3] -                        content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True) -                        if content is None: -                            return - -                        content, repl_code = content - -                        if len(content) == 2: -                            content = content[1] -                        else: -                            content = content[0] - -                        space_left = 204 -                        if len(content) >= space_left: -                            current_length = 0 -                            lines_walked = 0 -                            for line in content.splitlines(keepends=True): -                                if current_length + len(line) > space_left or lines_walked == 10: -                                    break -                                current_length += len(line) -                                lines_walked += 1 -                            content = content[:current_length] + "#..." -                        content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) -                        howto = ( -                            "It looks like you are trying to paste code into this channel.\n\n" -                            "You seem to be using the wrong symbols to indicate where the codeblock should start. " -                            f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n" -                            "**Here is an example of how it should look:**\n" -                            f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" -                            "**This will result in the following:**\n" -                            f"```python\n{content}\n```" -                        ) - -                    else: -                        howto = "" -                        content = self.codeblock_stripping(msg.content, False) -                        if content is None: -                            return - -                        content, repl_code = content -                        # Attempts to parse the message into an AST node. -                        # Invalid Python code will raise a SyntaxError. -                        tree = ast.parse(content[0]) - -                        # Multiple lines of single words could be interpreted as expressions. -                        # This check is to avoid all nodes being parsed as expressions. -                        # (e.g. words over multiple lines) -                        if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code: -                            # Shorten the code to 10 lines and/or 204 characters. -                            space_left = 204 -                            if content and repl_code: -                                content = content[1] -                            else: -                                content = content[0] - -                            if len(content) >= space_left: -                                current_length = 0 -                                lines_walked = 0 -                                for line in content.splitlines(keepends=True): -                                    if current_length + len(line) > space_left or lines_walked == 10: -                                        break -                                    current_length += len(line) -                                    lines_walked += 1 -                                content = content[:current_length] + "#..." - -                            content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content) -                            howto += ( -                                "It looks like you're trying to paste code into this channel.\n\n" -                                "Discord has support for Markdown, which allows you to post code with full " -                                "syntax highlighting. Please use these whenever you paste code, as this " -                                "helps improve the legibility and makes it easier for us to help you.\n\n" -                                f"**To do this, use the following method:**\n" -                                f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n" -                                "**This will result in the following:**\n" -                                f"```python\n{content}\n```" -                            ) - -                            log.debug(f"{msg.author} posted something that needed to be put inside python code " -                                      "blocks. Sending the user some instructions.") -                        else: -                            log.trace("The code consists only of expressions, not sending instructions") - -                    if howto != "": -                        # Increase amount of codeblock correction in stats -                        self.bot.stats.incr("codeblock_corrections") -                        howto_embed = Embed(description=howto) -                        bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed) -                        self.codeblock_message_ids[msg.id] = bot_message.id - -                        self.bot.loop.create_task( -                            wait_for_deletion(bot_message, (msg.author.id,), self.bot) -                        ) -                    else: -                        return - -                    if msg.channel.id not in self.channel_whitelist: -                        self.channel_cooldowns[msg.channel.id] = time.time() - -                except SyntaxError: -                    log.trace( -                        f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, " -                        "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. " -                        f"The message that was posted was:\n\n{msg.content}\n\n" -                    ) - -    @Cog.listener() -    async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None: -        """Check to see if an edited message (previously called out) still contains poorly formatted code.""" -        if ( -            # Checks to see if the message was called out by the bot -            payload.message_id not in self.codeblock_message_ids -            # Makes sure that there is content in the message -            or payload.data.get("content") is None -            # Makes sure there's a channel id in the message payload -            or payload.data.get("channel_id") is None -        ): -            return - -        # Retrieve channel and message objects for use later -        channel = self.bot.get_channel(int(payload.data.get("channel_id"))) -        user_message = await channel.fetch_message(payload.message_id) - -        #  Checks to see if the user has corrected their codeblock.  If it's fixed, has_fixed_codeblock will be None -        has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message)) - -        # If the message is fixed, delete the bot message and the entry from the id dictionary -        if has_fixed_codeblock is None: -            bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) -            await bot_message.delete() -            del self.codeblock_message_ids[payload.message_id] -            log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") -  def setup(bot: Bot) -> None:      """Load the Bot cog.""" diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index bf4e24661..3113a1149 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -23,7 +23,7 @@ from bot.utils.time import humanize_delta  log = logging.getLogger(__name__) -NAMESPACE = "reminder"  # Used for the mutually_exclusive decorator; constant to prevent typos +LOCK_NAMESPACE = "reminder"  WHITELISTED_CHANNELS = Guild.reminder_whitelist  MAXIMUM_REMINDERS = 5 @@ -170,7 +170,7 @@ class Reminders(Cog):          log.trace(f"Scheduling new task #{reminder['id']}")          self.schedule_reminder(reminder) -    @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True) +    @lock_arg(LOCK_NAMESPACE, "reminder", itemgetter("id"), raise_error=True)      async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:          """Send the reminder."""          is_valid, user, channel = self.ensure_valid_reminder(reminder) @@ -378,7 +378,7 @@ class Reminders(Cog):          mention_ids = [mention.id for mention in mentions]          await self.edit_reminder(ctx, id_, {"mentions": mention_ids}) -    @lock_arg(NAMESPACE, "id_", raise_error=True) +    @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)      async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None:          """Edits a reminder with the given payload, then sends a confirmation message."""          if not await self._can_modify(ctx, id_): @@ -398,7 +398,7 @@ class Reminders(Cog):          await self._reschedule_reminder(reminder)      @remind_group.command("delete", aliases=("remove", "cancel")) -    @lock_arg(NAMESPACE, "id_", raise_error=True) +    @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)      async def delete_reminder(self, ctx: Context, id_: int) -> None:          """Delete one of your active reminders."""          if not await self._can_modify(ctx, id_): diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index ca6fbf5cb..41cb00541 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -21,14 +21,12 @@ log = logging.getLogger(__name__)  ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")  FORMATTED_CODE_REGEX = re.compile( -    r"^\s*"                                 # any leading whitespace from the beginning of the string      r"(?P<delim>(?P<block>```)|``?)"        # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block      r"(?(block)(?:(?P<lang>[a-z]+)\n)?)"    # if we're in a block, match optional language (only letters plus newline)      r"(?:[ \t]*\n)*"                        # any blank (empty or tabs/spaces only) lines before the code      r"(?P<code>.*?)"                        # extract all code inside the markup      r"\s*"                                  # any more whitespace before the end of the code markup -    r"(?P=delim)"                           # match the exact same delimiter from the start again -    r"\s*$",                                # any trailing whitespace until the end of the string +    r"(?P=delim)",                          # match the exact same delimiter from the start again      re.DOTALL | re.IGNORECASE               # "." also matches newlines, case insensitive  )  RAW_CODE_REGEX = re.compile( @@ -38,11 +36,11 @@ RAW_CODE_REGEX = re.compile(      re.DOTALL                               # "." also matches newlines  ) -MAX_PASTE_LEN = 1000 +MAX_PASTE_LEN = 10000  # `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric, Channels.code_help_voice) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use) +EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) +EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use, Categories.voice)  EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)  SIGKILL = 9 @@ -76,23 +74,32 @@ class Snekbox(Cog):      @staticmethod      def prepare_input(code: str) -> str: -        """Extract code from the Markdown, format it, and insert it into the code template.""" -        match = FORMATTED_CODE_REGEX.fullmatch(code) -        if match: -            code, block, lang, delim = match.group("code", "block", "lang", "delim") -            code = textwrap.dedent(code) -            if block: -                info = (f"'{lang}' highlighted" if lang else "plain") + " code block" +        """ +        Extract code from the Markdown, format it, and insert it into the code template. + +        If there is any code block, ignore text outside the code block. +        Use the first code block, but prefer a fenced code block. +        If there are several fenced code blocks, concatenate only the fenced code blocks. +        """ +        if match := list(FORMATTED_CODE_REGEX.finditer(code)): +            blocks = [block for block in match if block.group("block")] + +            if len(blocks) > 1: +                code = '\n'.join(block.group("code") for block in blocks) +                info = "several code blocks"              else: -                info = f"{delim}-enclosed inline code" -            log.trace(f"Extracted {info} for evaluation:\n{code}") +                match = match[0] if len(blocks) == 0 else blocks[0] +                code, block, lang, delim = match.group("code", "block", "lang", "delim") +                if block: +                    info = (f"'{lang}' highlighted" if lang else "plain") + " code block" +                else: +                    info = f"{delim}-enclosed inline code"          else: -            code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code")) -            log.trace( -                f"Eval message contains unformatted or badly formatted code, " -                f"stripping whitespace only:\n{code}" -            ) +            code = RAW_CODE_REGEX.fullmatch(code).group("code") +            info = "unformatted or badly formatted code" +        code = textwrap.dedent(code) +        log.trace(f"Extracted {info} for evaluation:\n{code}")          return code      @staticmethod diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index 3e9230414..6d8d98695 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -2,9 +2,10 @@ import difflib  import logging  import re  import unicodedata +from datetime import datetime, timedelta  from email.parser import HeaderParser  from io import StringIO -from typing import Tuple, Union +from typing import Dict, Optional, Tuple, Union  from discord import Colour, Embed, utils  from discord.ext.commands import BadArgument, Cog, Context, clean_content, command, has_any_role @@ -14,6 +15,7 @@ from bot.constants import Channels, MODERATION_ROLES, STAFF_ROLES  from bot.decorators import in_whitelist  from bot.pagination import LinePaginator  from bot.utils import messages +from bot.utils.cache import AsyncCache  log = logging.getLogger(__name__) @@ -41,80 +43,21 @@ Namespaces are one honking great idea -- let's do more of those!  ICON_URL = "https://www.python.org/static/opengraph-icon-200x200.png" +pep_cache = AsyncCache() +  class Utils(Cog):      """A selection of utilities which don't have a clear category.""" +    BASE_PEP_URL = "http://www.python.org/dev/peps/pep-" +    BASE_GITHUB_PEP_URL = "https://raw.githubusercontent.com/python/peps/master/pep-" +    PEPS_LISTING_API_URL = "https://api.github.com/repos/python/peps/contents?ref=master" +      def __init__(self, bot: Bot):          self.bot = bot - -        self.base_pep_url = "http://www.python.org/dev/peps/pep-" -        self.base_github_pep_url = "https://raw.githubusercontent.com/python/peps/master/pep-" - -    @command(name='pep', aliases=('get_pep', 'p')) -    async def pep_command(self, ctx: Context, pep_number: str) -> None: -        """Fetches information about a PEP and sends it to the channel.""" -        if pep_number.isdigit(): -            pep_number = int(pep_number) -        else: -            await ctx.send_help(ctx.command) -            return - -        # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. -        if pep_number == 0: -            return await self.send_pep_zero(ctx) - -        possible_extensions = ['.txt', '.rst'] -        found_pep = False -        for extension in possible_extensions: -            # Attempt to fetch the PEP -            pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}" -            log.trace(f"Requesting PEP {pep_number} with {pep_url}") -            response = await self.bot.http_session.get(pep_url) - -            if response.status == 200: -                log.trace("PEP found") -                found_pep = True - -                pep_content = await response.text() - -                # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 -                pep_header = HeaderParser().parse(StringIO(pep_content)) - -                # Assemble the embed -                pep_embed = Embed( -                    title=f"**PEP {pep_number} - {pep_header['Title']}**", -                    url=f"{self.base_pep_url}{pep_number:04}" -                ) - -                pep_embed.set_thumbnail(url=ICON_URL) - -                # Add the interesting information -                fields_to_check = ("Status", "Python-Version", "Created", "Type") -                for field in fields_to_check: -                    # Check for a PEP metadata field that is present but has an empty value -                    # embed field values can't contain an empty string -                    if pep_header.get(field, ""): -                        pep_embed.add_field(name=field, value=pep_header[field]) - -            elif response.status != 404: -                # any response except 200 and 404 is expected -                found_pep = True  # actually not, but it's easier to display this way -                log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: " -                          f"{response.status}.\n{response.text}") - -                error_message = "Unexpected HTTP error during PEP search. Please let us know." -                pep_embed = Embed(title="Unexpected error", description=error_message) -                pep_embed.colour = Colour.red() -                break - -        if not found_pep: -            log.trace("PEP was not found") -            not_found = f"PEP {pep_number} does not exist." -            pep_embed = Embed(title="PEP not found", description=not_found) -            pep_embed.colour = Colour.red() - -        await ctx.message.channel.send(embed=pep_embed) +        self.peps: Dict[int, str] = {} +        self.last_refreshed_peps: Optional[datetime] = None +        self.bot.loop.create_task(self.refresh_peps_urls())      @command()      @in_whitelist(channels=(Channels.bot_commands,), roles=STAFF_ROLES) @@ -246,8 +189,53 @@ class Utils(Cog):          for reaction in options:              await message.add_reaction(reaction) -    async def send_pep_zero(self, ctx: Context) -> None: -        """Send information about PEP 0.""" +    # region: PEP + +    async def refresh_peps_urls(self) -> None: +        """Refresh PEP URLs listing in every 3 hours.""" +        # Wait until HTTP client is available +        await self.bot.wait_until_ready() +        log.trace("Started refreshing PEP URLs.") + +        async with self.bot.http_session.get(self.PEPS_LISTING_API_URL) as resp: +            listing = await resp.json() + +        log.trace("Got PEP URLs listing from GitHub API") + +        for file in listing: +            name = file["name"] +            if name.startswith("pep-") and name.endswith((".rst", ".txt")): +                pep_number = name.replace("pep-", "").split(".")[0] +                self.peps[int(pep_number)] = file["download_url"] + +        self.last_refreshed_peps = datetime.now() +        log.info("Successfully refreshed PEP URLs listing.") + +    @command(name='pep', aliases=('get_pep', 'p')) +    async def pep_command(self, ctx: Context, pep_number: int) -> None: +        """Fetches information about a PEP and sends it to the channel.""" +        # Trigger typing in chat to show users that bot is responding +        await ctx.trigger_typing() + +        # Handle PEP 0 directly because it's not in .rst or .txt so it can't be accessed like other PEPs. +        if pep_number == 0: +            pep_embed = self.get_pep_zero_embed() +            success = True +        else: +            success = False +            if not (pep_embed := await self.validate_pep_number(pep_number)): +                pep_embed, success = await self.get_pep_embed(pep_number) + +        await ctx.send(embed=pep_embed) +        if success: +            log.trace(f"PEP {pep_number} getting and sending finished successfully. Increasing stat.") +            self.bot.stats.incr(f"pep_fetches.{pep_number}") +        else: +            log.trace(f"Getting PEP {pep_number} failed. Error embed sent.") + +    @staticmethod +    def get_pep_zero_embed() -> Embed: +        """Get information embed about PEP 0."""          pep_embed = Embed(              title="**PEP 0 - Index of Python Enhancement Proposals (PEPs)**",              url="https://www.python.org/dev/peps/" @@ -257,7 +245,69 @@ class Utils(Cog):          pep_embed.add_field(name="Created", value="13-Jul-2000")          pep_embed.add_field(name="Type", value="Informational") -        await ctx.send(embed=pep_embed) +        return pep_embed + +    async def validate_pep_number(self, pep_nr: int) -> Optional[Embed]: +        """Validate is PEP number valid. When it isn't, return error embed, otherwise None.""" +        if ( +            pep_nr not in self.peps +            and (self.last_refreshed_peps + timedelta(minutes=30)) <= datetime.now() +            and len(str(pep_nr)) < 5 +        ): +            await self.refresh_peps_urls() + +        if pep_nr not in self.peps: +            log.trace(f"PEP {pep_nr} was not found") +            return Embed( +                title="PEP not found", +                description=f"PEP {pep_nr} does not exist.", +                colour=Colour.red() +            ) + +        return None + +    def generate_pep_embed(self, pep_header: Dict, pep_nr: int) -> Embed: +        """Generate PEP embed based on PEP headers data.""" +        # Assemble the embed +        pep_embed = Embed( +            title=f"**PEP {pep_nr} - {pep_header['Title']}**", +            description=f"[Link]({self.BASE_PEP_URL}{pep_nr:04})", +        ) + +        pep_embed.set_thumbnail(url=ICON_URL) + +        # Add the interesting information +        fields_to_check = ("Status", "Python-Version", "Created", "Type") +        for field in fields_to_check: +            # Check for a PEP metadata field that is present but has an empty value +            # embed field values can't contain an empty string +            if pep_header.get(field, ""): +                pep_embed.add_field(name=field, value=pep_header[field]) + +        return pep_embed + +    @pep_cache(arg_offset=1) +    async def get_pep_embed(self, pep_nr: int) -> Tuple[Embed, bool]: +        """Fetch, generate and return PEP embed. Second item of return tuple show does getting success.""" +        response = await self.bot.http_session.get(self.peps[pep_nr]) + +        if response.status == 200: +            log.trace(f"PEP {pep_nr} found") +            pep_content = await response.text() + +            # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179 +            pep_header = HeaderParser().parse(StringIO(pep_content)) +            return self.generate_pep_embed(pep_header, pep_nr), True +        else: +            log.trace( +                f"The user requested PEP {pep_nr}, but the response had an unexpected status code: {response.status}." +            ) +            return Embed( +                title="Unexpected error", +                description="Unexpected HTTP error during PEP search. Please let us know.", +                colour=Colour.red() +            ), False +    # endregion  def setup(bot: Bot) -> None: diff --git a/bot/resources/tags/codeblock.md b/bot/resources/tags/codeblock.md index a28ae397b..8d48bdf06 100644 --- a/bot/resources/tags/codeblock.md +++ b/bot/resources/tags/codeblock.md @@ -1,17 +1,7 @@ -Discord has support for Markdown, which allows you to post code with full syntax highlighting. Please use these whenever you paste code, as this helps improve the legibility and makes it easier for us to help you. +Here's how to format Python code on Discord: -To do this, use the following method: - -\```python +\```py  print('Hello world!')  \``` -Note:   -• **These are backticks, not quotes.** Backticks can usually be found on the tilde key.   -• You can also use py as the language instead of python   -• The language must be on the first line next to the backticks with **no** space between them   - -This will result in the following: -```py -print('Hello world!') -``` +**These are backticks, not quotes.** Check [this](https://superuser.com/questions/254076/how-do-i-type-the-tick-and-backtick-characters-on-windows/254077#254077) out if you can't find the backtick key. diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 60170a88f..13533a467 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,4 +1,4 @@ -from bot.utils.helpers import CogABCMeta, find_nth_occurrence, pad_base64 +from bot.utils.helpers import CogABCMeta, find_nth_occurrence, has_lines, pad_base64  from bot.utils.services import send_to_paste_service -__all__ = ['CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service'] +__all__ = ['CogABCMeta', 'find_nth_occurrence', 'has_lines', 'pad_base64', 'send_to_paste_service'] diff --git a/bot/utils/cache.py b/bot/utils/cache.py new file mode 100644 index 000000000..68ce15607 --- /dev/null +++ b/bot/utils/cache.py @@ -0,0 +1,41 @@ +import functools +from collections import OrderedDict +from typing import Any, Callable + + +class AsyncCache: +    """ +    LRU cache implementation for coroutines. + +    Once the cache exceeds the maximum size, keys are deleted in FIFO order. + +    An offset may be optionally provided to be applied to the coroutine's arguments when creating the cache key. +    """ + +    def __init__(self, max_size: int = 128): +        self._cache = OrderedDict() +        self._max_size = max_size + +    def __call__(self, arg_offset: int = 0) -> Callable: +        """Decorator for async cache.""" + +        def decorator(function: Callable) -> Callable: +            """Define the async cache decorator.""" + +            @functools.wraps(function) +            async def wrapper(*args) -> Any: +                """Decorator wrapper for the caching logic.""" +                key = args[arg_offset:] + +                if key not in self._cache: +                    if len(self._cache) > self._max_size: +                        self._cache.popitem(last=False) + +                    self._cache[key] = await function(*args) +                return self._cache[key] +            return wrapper +        return decorator + +    def clear(self) -> None: +        """Clear cache instance.""" +        self._cache.clear() diff --git a/bot/utils/channel.py b/bot/utils/channel.py new file mode 100644 index 000000000..6bf70bfde --- /dev/null +++ b/bot/utils/channel.py @@ -0,0 +1,49 @@ +import logging + +import discord + +from bot import constants +from bot.constants import Categories + +log = logging.getLogger(__name__) + + +def is_help_channel(channel: discord.TextChannel) -> bool: +    """Return True if `channel` is in one of the help categories (excluding dormant).""" +    log.trace(f"Checking if #{channel} is a help channel.") +    categories = (Categories.help_available, Categories.help_in_use) + +    return any(is_in_category(channel, category) for category in categories) + + +def is_mod_channel(channel: discord.TextChannel) -> bool: +    """True if `channel` is considered a mod channel.""" +    if channel.id in constants.MODERATION_CHANNELS: +        log.trace(f"Channel #{channel} is a configured mod channel") +        return True + +    elif any(is_in_category(channel, category) for category in constants.MODERATION_CATEGORIES): +        log.trace(f"Channel #{channel} is in a configured mod category") +        return True + +    else: +        log.trace(f"Channel #{channel} is not a mod channel") +        return False + + +def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: +    """Return True if `channel` is within a category with `category_id`.""" +    return getattr(channel, "category_id", None) == category_id + + +async def try_get_channel(channel_id: int, client: discord.Client) -> discord.abc.GuildChannel: +    """Attempt to get or fetch a channel and return it.""" +    log.trace(f"Getting the channel {channel_id}.") + +    channel = client.get_channel(channel_id) +    if not channel: +        log.debug(f"Channel {channel_id} is not in cache; fetching from API.") +        channel = await client.fetch_channel(channel_id) + +    log.trace(f"Channel #{channel} ({channel_id}) retrieved.") +    return channel diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py index d9b60af07..3501a3933 100644 --- a/bot/utils/helpers.py +++ b/bot/utils/helpers.py @@ -18,6 +18,15 @@ def find_nth_occurrence(string: str, substring: str, n: int) -> Optional[int]:      return index +def has_lines(string: str, count: int) -> bool: +    """Return True if `string` has at least `count` lines.""" +    # Benchmarks show this is significantly faster than using str.count("\n") or a for loop & break. +    split = string.split("\n", count - 1) + +    # Make sure the last part isn't empty, which would happen if there was a final newline. +    return split[-1] and len(split) == count + +  def pad_base64(data: str) -> str:      """Return base64 `data` with padding characters to ensure its length is a multiple of 4."""      return data + "=" * (-len(data) % 4) diff --git a/config-default.yml b/config-default.yml index 4f7b1e217..2afdcd594 100644 --- a/config-default.yml +++ b/config-default.yml @@ -119,6 +119,7 @@ style:          voice_state_green: "https://cdn.discordapp.com/emojis/656899770094452754.png"          voice_state_red: "https://cdn.discordapp.com/emojis/656899769905709076.png" +  guild:      id: 267624335836053506      invite: "https://discord.gg/python" @@ -127,7 +128,9 @@ guild:          help_available:                     691405807388196926          help_in_use:                        696958401460043776          help_dormant:                       691405908919451718 -        modmail:                            714494672835444826 +        modmail:            &MODMAIL        714494672835444826 +        logs:               &LOGS           468520609152892958 +        voice:                              356013253765234688      channels:          # Public announcement and news channels @@ -145,8 +148,8 @@ guild:          dev_log:            &DEV_LOG        622895325144940554          # Discussion -        meta:               429409067623251969 -        python_discussion:  267624335836053506 +        meta:                               429409067623251969 +        python_discussion:  &PY_DISCUSSION  267624335836053506          # Python Help: Available          how_to_get_help:    704250143020417084 @@ -169,6 +172,7 @@ guild:          bot_commands:       &BOT_CMD        267659945086812160          esoteric:                           470884583684964352          verification:                       352442727016693763 +        voice_gate:                         764802555427029012          # Staff          admins:             &ADMINS         365960823622991872 @@ -178,7 +182,7 @@ guild:          incidents:                          714214212200562749          incidents_archive:                  720668923636351037          mods:               &MODS           305126844661760000 -        mod_alerts:         &MOD_ALERTS     473092532147060736 +        mod_alerts:                         473092532147060736          mod_spam:           &MOD_SPAM       620607373828030464          organisation:       &ORGANISATION   551789653284356126          staff_lounge:       &STAFF_LOUNGE   464905259261755392 @@ -191,6 +195,8 @@ guild:          # Voice          code_help_voice:                    755154969761677312 +        code_help_voice_2:                  766330079135268884 +        voice_chat:                         412357430186344448          admins_voice:       &ADMINS_VOICE   500734494840717332          staff_voice:        &STAFF_VOICE    412375055910043655 @@ -198,10 +204,13 @@ guild:          big_brother_logs:   &BB_LOGS        468507907357409333          talent_pool:        &TALENT_POOL    534321732593647616 +    moderation_categories: +        - *MODMAIL +        - *LOGS +      moderation_channels:          - *ADMINS          - *ADMIN_SPAM -        - *MOD_ALERTS          - *MODS          - *MOD_SPAM @@ -225,9 +234,11 @@ guild:          muted:              &MUTED_ROLE         277914926603829249          partners:                               323426753857191936          python_community:   &PY_COMMUNITY_ROLE  458226413825294336 +        sprinters:          &SPRINTERS          758422482289426471          unverified:                             739794855945044069          verified:                               352427296948486144  # @Developers on PyDis +        voice_verified:                         764802720779337729          # Staff          admins:             &ADMINS_ROLE    267628507062992896 @@ -261,6 +272,7 @@ guild:          reddit:                             635408384794951680          talent_pool:                        569145364800602132 +  filter:      # What do we filter?      filter_zalgo:          false @@ -298,6 +310,7 @@ filter:          - *OWNERS_ROLE          - *HELPERS_ROLE          - *PY_COMMUNITY_ROLE +        - *SPRINTERS  keys: @@ -326,6 +339,7 @@ urls:      bot_avatar:      "https://raw.githubusercontent.com/discord-python/branding/master/logos/logo_circle/logo_circle.png"      github_bot_repo: "https://github.com/python-discord/bot" +  anti_spam:      # Clean messages that violate a rule.      clean_offending: true @@ -394,6 +408,23 @@ big_brother:      header_message_limit: 15 +code_block: +    # The channels in which code blocks will be detected. They are not subject to a cooldown. +    channel_whitelist: +        - *BOT_CMD + +    # The channels which will be affected by a cooldown. These channels are also whitelisted. +    cooldown_channels: +        - *PY_DISCUSSION + +    # Sending instructions triggers a cooldown on a per-channel basis. +    # More instruction messages will not be sent in the same channel until the cooldown has elapsed. +    cooldown_seconds: 300 + +    # The minimum amount of lines a message or code block must have for instructions to be sent. +    minimum_lines: 4 + +  free:      # Seconds to elapse for a channel      # to be considered inactive. @@ -442,10 +473,12 @@ help_channels:      notify_roles:          - *HELPERS_ROLE +  redirect_output:      delete_invocation: true      delete_delay: 15 +  duck_pond:      threshold: 4      channel_blacklist: @@ -461,11 +494,13 @@ duck_pond:          - *MOD_ANNOUNCEMENTS          - *ADMIN_ANNOUNCEMENTS +  python_news:      mail_lists:          - 'python-ideas'          - 'python-announce-list'          - 'pypi-announce' +        - 'python-dev'      channel: *PYNEWS_CHANNEL      webhook: *PYNEWS_WEBHOOK @@ -482,5 +517,12 @@ verification:      kick_confirmation_threshold: 0.01  # 1% +voice_gate: +    minimum_days_verified: 3  # How many days the user must have been verified for +    minimum_messages: 50  # How many messages a user must have to be eligible for voice +    bot_message_delete_delay: 10  # Seconds before deleting bot's response in Voice Gate +    minimum_activity_blocks: 3  # Number of 10 minute blocks during which a user must have been active + +  config:      required_keys: ['bot.token'] diff --git a/docker-compose.yml b/docker-compose.yml index cff7d33d6..8be5aac0e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -41,6 +41,7 @@ services:        - postgres      environment:        DATABASE_URL: postgres://pysite:pysite@postgres:5432/pysite +      METRICITY_DB_URL: postgres://pysite:pysite@postgres:5432/metricity        SECRET_KEY: suitable-for-development-only        STATIC_ROOT: /var/www/static diff --git a/tests/_autospec.py b/tests/_autospec.py new file mode 100644 index 000000000..ee2fc1973 --- /dev/null +++ b/tests/_autospec.py @@ -0,0 +1,64 @@ +import contextlib +import functools +import unittest.mock +from typing import Callable + + [email protected](unittest.mock._patch.decoration_helper) +def _decoration_helper(self, patched, args, keywargs): +    """Skips adding patchings as args if their `dont_pass` attribute is True.""" +    # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added. +    extra_args = [] +    with contextlib.ExitStack() as exit_stack: +        for patching in patched.patchings: +            arg = exit_stack.enter_context(patching) +            if not getattr(patching, "dont_pass", False): +                # Only add the patching as an arg if dont_pass is False. +                if patching.attribute_name is not None: +                    keywargs.update(arg) +                elif patching.new is unittest.mock.DEFAULT: +                    extra_args.append(arg) + +        args += tuple(extra_args) +        yield args, keywargs + + [email protected](unittest.mock._patch.copy) +def _copy(self): +    """Copy the `dont_pass` attribute along with the standard copy operation.""" +    patcher_copy = _copy.original(self) +    patcher_copy.dont_pass = getattr(self, "dont_pass", False) +    return patcher_copy + + +# Monkey-patch the patcher class :) +_copy.original = unittest.mock._patch.copy +unittest.mock._patch.copy = _copy +unittest.mock._patch.decoration_helper = _decoration_helper + + +def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable: +    """ +    Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True. + +    If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object. +    """ +    # Caller's kwargs should take priority and overwrite the defaults. +    kwargs = dict(spec_set=True, autospec=True) +    kwargs.update(patch_kwargs) + +    # Import the target if it's a string. +    # This is to support both object and string targets like patch.multiple. +    if type(target) is str: +        target = unittest.mock._importer(target) + +    def decorator(func): +        for attribute in attributes: +            patcher = unittest.mock.patch.object(target, attribute, **kwargs) +            if not pass_mocks: +                # A custom attribute to keep track of which patchings should be skipped. +                patcher.dont_pass = True +            func = patcher(func) +        return func +    return decorator diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py index be1b649e1..bf557a484 100644 --- a/tests/bot/exts/moderation/infraction/test_infractions.py +++ b/tests/bot/exts/moderation/infraction/test_infractions.py @@ -1,7 +1,8 @@  import textwrap  import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch +from bot.constants import Event  from bot.exts.moderation.infraction.infractions import Infractions  from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole @@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):          self.cog.apply_infraction.assert_awaited_once_with(              self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value          ) + + +@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456) +class VoiceBanTests(unittest.IsolatedAsyncioTestCase): +    """Tests for voice ban related functions and commands.""" + +    def setUp(self): +        self.bot = MockBot() +        self.mod = MockMember(top_role=10) +        self.user = MockMember(top_role=1, roles=[MockRole(id=123456)]) +        self.guild = MockGuild() +        self.ctx = MockContext(bot=self.bot, author=self.mod) +        self.cog = Infractions(self.bot) + +    async def test_permanent_voice_ban(self): +        """Should call voice ban applying function without expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar") + +    async def test_temporary_voice_ban(self): +        """Should call voice ban applying function with expiry.""" +        self.cog.apply_voice_ban = AsyncMock() +        self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar")) +        self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz") + +    async def test_voice_unban(self): +        """Should call infraction pardoning function.""" +        self.cog.pardon_infraction = AsyncMock() +        self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user)) +        self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock): +        """Should return early when user already have Voice Ban infraction.""" +        get_active_infraction.return_value = {"foo": "bar"} +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban") +        post_infraction_mock.assert_not_awaited() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock): +        """Should return early when posting infraction fails.""" +        self.cog.mod_log.ignore = MagicMock() +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        post_infraction_mock.assert_awaited_once() +        self.cog.mod_log.ignore.assert_not_called() + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock): +        """Should pass all kwargs passed to apply_voice_ban to post_infraction.""" +        get_active_infraction.return_value = None +        # We don't want that this continue yet +        post_infraction_mock.return_value = None +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23)) +        post_infraction_mock.assert_awaited_once_with( +            self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23 +        ) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id) + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock): +        """Should ignore Voice Verified role removing.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar")) +        self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar") +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction") +    @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction") +    async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock): +        """Should truncate reason for voice ban.""" +        self.cog.mod_log.ignore = MagicMock() +        self.cog.apply_infraction = AsyncMock() +        self.user.remove_roles = MagicMock(return_value="my_return_value") + +        get_active_infraction.return_value = None +        post_infraction_mock.return_value = {"foo": "bar"} + +        self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000)) +        self.user.remove_roles.assert_called_once_with( +            self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...") +        ) +        self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value") + +    async def test_voice_unban_user_not_found(self): +        """Should include info to return dict when user was not found from guild.""" +        self.guild.get_member.return_value = None +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, {"Info": "User was not found in the guild."}) + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = True +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "Sent" +        }) +        notify_pardon_mock.assert_awaited_once() + +    @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon") +    @patch("bot.exts.moderation.infraction.infractions.format_user") +    async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock): +        """Should add role back with ignoring, notify user and return log dictionary..""" +        self.guild.get_member.return_value = self.user +        notify_pardon_mock.return_value = False +        format_user_mock.return_value = "my-user" + +        result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar") +        self.assertEqual(result, { +            "Member": "my-user", +            "DM": "**Failed**" +        }) +        notify_pardon_mock.assert_awaited_once() diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py index 3c2d52ae0..104293d8e 100644 --- a/tests/bot/exts/moderation/test_silence.py +++ b/tests/bot/exts/moderation/test_silence.py @@ -1,23 +1,49 @@ +import asyncio  import unittest +from datetime import datetime, timezone  from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock +from async_rediscache import RedisSession  from discord import PermissionOverwrite -from bot.constants import Channels, Emojis, Guild, Roles -from bot.exts.moderation.silence import Silence, SilenceNotifier -from tests.helpers import MockBot, MockContext, MockTextChannel +from bot.constants import Channels, Guild, Roles +from bot.exts.moderation import silence +from tests.helpers import MockBot, MockContext, MockTextChannel, autospec + +redis_session = None +redis_loop = asyncio.get_event_loop() + + +def setUpModule():  # noqa: N802 +    """Create and connect to the fakeredis session.""" +    global redis_session +    redis_session = RedisSession(use_fakeredis=True) +    redis_loop.run_until_complete(redis_session.connect()) + + +def tearDownModule():  # noqa: N802 +    """Close the fakeredis session.""" +    if redis_session: +        redis_loop.run_until_complete(redis_session.close()) + + +# Have to subclass it because builtins can't be patched. +class PatchedDatetime(datetime): +    """A datetime object with a mocked now() function.""" + +    now = mock.create_autospec(datetime, "now")  class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):      def setUp(self) -> None:          self.alert_channel = MockTextChannel() -        self.notifier = SilenceNotifier(self.alert_channel) +        self.notifier = silence.SilenceNotifier(self.alert_channel)          self.notifier.stop = self.notifier_stop_mock = Mock()          self.notifier.start = self.notifier_start_mock = Mock()      def test_add_channel_adds_channel(self): -        """Channel in FirstHash with current loop is added to internal set.""" +        """Channel is added to `_silenced_channels` with the current loop."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.add_channel(channel) @@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):          self.notifier_start_mock.assert_not_called()      def test_remove_channel_removes_channel(self): -        """Channel in FirstHash is removed from `_silenced_channels`.""" +        """Channel is removed from `_silenced_channels`."""          channel = Mock()          with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:              self.notifier.remove_channel(channel) @@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):              with self.subTest(current_loop=current_loop):                  with mock.patch.object(self.notifier, "_current_loop", new=current_loop):                      await self.notifier._notifier() -                self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ") +                self.alert_channel.send.assert_called_once_with( +                    f"<@&{Roles.moderators}> currently silenced channels: " +                )              self.alert_channel.send.reset_mock()      async def test_notifier_skips_alert(self): @@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):                      self.alert_channel.send.assert_not_called() -class SilenceTests(unittest.IsolatedAsyncioTestCase): +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceCogTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the general functionality of the Silence cog.""" + +    @autospec(silence, "Scheduler", pass_mocks=False)      def setUp(self) -> None:          self.bot = MockBot() -        self.cog = Silence(self.bot) -        self.ctx = MockContext() -        self.cog._verified_role = None -        # Set event so command callbacks can continue. -        self.cog._get_instance_vars_event.set() +        self.cog = silence.Silence(self.bot) -    async def test_instance_vars_got_guild(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_guild(self):          """Bot got guild after it became available.""" -        await self.cog._get_instance_vars() -        self.bot.wait_until_guild_available.assert_called_once() +        await self.cog._async_init() +        self.bot.wait_until_guild_available.assert_awaited_once()          self.bot.get_guild.assert_called_once_with(Guild.id) -    async def test_instance_vars_got_role(self): +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_role(self):          """Got `Roles.verified` role from guild.""" -        await self.cog._get_instance_vars()          guild = self.bot.get_guild() -        guild.get_role.assert_called_once_with(Roles.verified) +        guild.get_role.side_effect = lambda id_: Mock(id=id_) -    async def test_instance_vars_got_channels(self): +        await self.cog._async_init() +        self.assertEqual(self.cog._verified_role.id, Roles.verified) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_got_channels(self):          """Got channels from bot.""" -        await self.cog._get_instance_vars() -        self.bot.get_channel.called_once_with(Channels.mod_alerts) -        self.bot.get_channel.called_once_with(Channels.mod_log) +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts) -    @mock.patch("bot.exts.moderation.silence.SilenceNotifier") -    async def test_instance_vars_got_notifier(self, notifier): +    @autospec(silence, "SilenceNotifier") +    async def test_async_init_got_notifier(self, notifier):          """Notifier was started with channel.""" -        mod_log = MockTextChannel() -        self.bot.get_channel.side_effect = (None, mod_log) -        await self.cog._get_instance_vars() -        notifier.assert_called_once_with(mod_log) -        self.bot.get_channel.side_effect = None - -    async def test_silence_sent_correct_discord_message(self): -        """Check if proper message was sent when called with duration in channel with previous state.""" +        self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_) + +        await self.cog._async_init() +        notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log)) +        self.assertEqual(self.cog.notifier, notifier.return_value) + +    @autospec(silence, "SilenceNotifier", pass_mocks=False) +    async def test_async_init_rescheduled(self): +        """`_reschedule_` coroutine was awaited.""" +        self.cog._reschedule = mock.create_autospec(self.cog._reschedule) +        await self.cog._async_init() +        self.cog._reschedule.assert_awaited_once_with() + +    def test_cog_unload_cancelled_tasks(self): +        """The init task was cancelled.""" +        self.cog._init_task = asyncio.Future() +        self.cog.cog_unload() + +        # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda. +        self.assertTrue(self.cog._init_task.cancelled()) + +    @autospec("discord.ext.commands", "has_any_role") +    @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3)) +    async def test_cog_check(self, role_check): +        """Role check was called with `MODERATION_ROLES`""" +        ctx = MockContext() +        role_check.return_value.predicate = mock.AsyncMock() + +        await self.cog.cog_check(ctx) +        role_check.assert_called_once_with(*(1, 2, 3)) +        role_check.return_value.predicate.assert_awaited_once_with(ctx) + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class RescheduleTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the rescheduling of cached unsilences.""" + +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self): +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper) + +        with mock.patch.object(self.cog, "_reschedule", autospec=True): +            asyncio.run(self.cog._async_init())  # Populate instance attributes. + +    async def test_skipped_missing_channel(self): +        """Did nothing because the channel couldn't be retrieved.""" +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)] +        self.bot.get_channel.return_value = None + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_added_permanent_to_notifier(self): +        """Permanently silenced channels were added to the notifier.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)] + +        await self.cog._reschedule() + +        self.cog.notifier.add_channel.assert_any_call(channels[0]) +        self.cog.notifier.add_channel.assert_any_call(channels[1]) + +        self.cog._unsilence_wrapper.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    async def test_unsilenced_expired(self): +        """Unsilenced expired silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)] + +        await self.cog._reschedule() + +        self.cog._unsilence_wrapper.assert_any_call(channels[0]) +        self.cog._unsilence_wrapper.assert_any_call(channels[1]) + +        self.cog.notifier.add_channel.assert_not_called() +        self.cog.scheduler.schedule_later.assert_not_called() + +    @mock.patch.object(silence, "datetime", new=PatchedDatetime) +    async def test_rescheduled_active(self): +        """Rescheduled active silences.""" +        channels = [MockTextChannel(id=123), MockTextChannel(id=456)] +        self.bot.get_channel.side_effect = channels +        self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)] +        silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc) + +        self.cog._unsilence_wrapper = mock.MagicMock() +        unsilence_return = self.cog._unsilence_wrapper.return_value + +        await self.cog._reschedule() + +        # Yuck. +        calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)] +        self.cog.scheduler.schedule_later.assert_has_calls(calls) + +        unsilence_calls = [mock.call(channel) for channel in channels] +        self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls) + +        self.cog.notifier.add_channel.assert_not_called() + + +@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False) +class SilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the silence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot() +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        # Avoid unawaited coroutine warnings. +        self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close() + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command."""          test_cases = ( -            (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,), -            (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,), -            (5, f"{Emojis.cross_mark} current channel is already silenced.", False,), +            (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,), +            (None, silence.MSG_SILENCE_PERMANENT, True,), +            (5, silence.MSG_SILENCE_FAIL, False,),          ) -        for duration, result_message, _silence_patch_return in test_cases: -            with self.subTest( -                silence_duration=duration, -                result_message=result_message, -                starting_unsilenced_state=_silence_patch_return -            ): -                with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return): -                    await self.cog.silence(self.cog, self.ctx, duration) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_unsilence_sent_correct_discord_message(self): -        """Check if proper message was sent when unsilencing channel.""" -        test_cases = ( -            (True, f"{Emojis.check_mark} unsilenced current channel."), -            (False, f"{Emojis.cross_mark} current channel was not silenced.") +        for duration, message, was_silenced in test_cases: +            ctx = MockContext() +            with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced): +                with self.subTest(was_silenced=was_silenced, message=message, duration=duration): +                    await self.cog.silence.callback(self.cog, ctx, duration) +                    ctx.send.assert_called_once_with(message) + +    async def test_skipped_already_silenced(self): +        """Permissions were not set and `False` was returned for an already silenced channel.""" +        subtests = ( +            (False, PermissionOverwrite(send_messages=False, add_reactions=False)), +            (True, PermissionOverwrite(send_messages=True, add_reactions=True)), +            (True, PermissionOverwrite(send_messages=False, add_reactions=False)),          ) -        for _unsilence_patch_return, result_message in test_cases: -            with self.subTest( -                starting_silenced_state=_unsilence_patch_return, -                result_message=result_message -            ): -                with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return): -                    await self.cog.unsilence(self.cog, self.ctx) -                    self.ctx.send.assert_called_once_with(result_message) -            self.ctx.reset_mock() - -    async def test_silence_private_for_false(self): -        """Permissions are not set and `False` is returned in an already silenced channel.""" -        perm_overwrite = Mock(send_messages=False) -        channel = Mock(overwrites_for=Mock(return_value=perm_overwrite)) - -        self.assertFalse(await self.cog._silence(channel, True, None)) -        channel.set_permissions.assert_not_called() -    async def test_silence_private_silenced_channel(self): -        """Channel had `send_message` permissions revoked.""" -        channel = MockTextChannel() -        self.assertTrue(await self.cog._silence(channel, False, None)) -        channel.set_permissions.assert_called_once() -        self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages']) +        for contains, overwrite in subtests: +            with self.subTest(contains=contains, overwrite=overwrite): +                self.cog.scheduler.__contains__.return_value = contains +                channel = MockTextChannel() +                channel.overwrites_for.return_value = overwrite + +                self.assertFalse(await self.cog._set_silence_overwrites(channel)) +                channel.set_permissions.assert_not_called() + +    async def test_silenced_channel(self): +        """Channel had `send_message` and `add_reactions` permissions revoked for verified role.""" +        self.assertTrue(await self.cog._set_silence_overwrites(self.channel)) +        self.assertFalse(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite +        ) -    async def test_silence_private_preserves_permissions(self): -        """Previous permissions were preserved when channel was silenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite() -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._silence(channel, False, None) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    async def test_silence_private_notifier(self): -        """Channel should be added to notifier with `persistent` set to `True`, and the other way around.""" -        channel = MockTextChannel() -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=True): -                await self.cog._silence(channel, True, None) -                self.cog.notifier.add_channel.assert_called_once() - -        with mock.patch.object(self.cog, "notifier", create=True): -            with self.subTest(persistent=False): -                await self.cog._silence(channel, False, None) -                self.cog.notifier.add_channel.assert_not_called() - -    async def test_silence_private_added_muted_channel(self): -        """Channel was added to `muted_channels` on silence.""" +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed.""" +        prev_overwrite_dict = dict(self.overwrite) +        await self.cog._set_silence_overwrites(self.channel) +        new_overwrite_dict = dict(self.overwrite) + +        # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method. +        del prev_overwrite_dict['send_messages'] +        del prev_overwrite_dict['add_reactions'] +        del new_overwrite_dict['send_messages'] +        del new_overwrite_dict['add_reactions'] + +        self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) + +    async def test_temp_not_added_to_notifier(self): +        """Channel was not added to notifier if a duration was set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_indefinite_added_to_notifier(self): +        """Channel was added to notifier if a duration was not set for the silence.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True): +            await self.cog.silence.callback(self.cog, MockContext(), None) +            self.cog.notifier.add_channel.assert_called_once() + +    async def test_silenced_not_added_to_notifier(self): +        """Channel was not added to the notifier if it was already silenced.""" +        with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False): +            await self.cog.silence.callback(self.cog, MockContext(), 15) +            self.cog.notifier.add_channel.assert_not_called() + +    async def test_cached_previous_overwrites(self): +        """Channel's previous overwrites were cached.""" +        overwrite_json = '{"send_messages": true, "add_reactions": false}' +        await self.cog._set_silence_overwrites(self.channel) +        self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json) + +    @autospec(silence, "datetime") +    async def test_cached_unsilence_time(self, datetime_mock): +        """The UTC POSIX timestamp for the unsilence was cached.""" +        now_timestamp = 100 +        duration = 15 +        timestamp = now_timestamp + duration * 60 +        datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc) + +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, duration) + +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp) +        datetime_mock.now.assert_called_once_with(tz=timezone.utc)  # Ensure it's using an aware dt. + +    async def test_cached_indefinite_time(self): +        """A value of -1 was cached for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1) + +    async def test_scheduled_task(self): +        """An unsilence task was scheduled.""" +        ctx = MockContext(channel=self.channel, invoke=mock.MagicMock()) + +        await self.cog.silence.callback(self.cog, ctx, 5) + +        args = (300, ctx.channel.id, ctx.invoke.return_value) +        self.cog.scheduler.schedule_later.assert_called_once_with(*args) +        ctx.invoke.assert_called_once_with(self.cog.unsilence) + +    async def test_permanent_not_scheduled(self): +        """A task was not scheduled for a permanent silence.""" +        ctx = MockContext(channel=self.channel) +        await self.cog.silence.callback(self.cog, ctx, None) +        self.cog.scheduler.schedule_later.assert_not_called() + + +@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False) +class UnsilenceTests(unittest.IsolatedAsyncioTestCase): +    """Tests for the unsilence command and its related helper methods.""" + +    @autospec(silence.Silence, "_reschedule", pass_mocks=False) +    @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False) +    def setUp(self) -> None: +        self.bot = MockBot(get_channel=lambda _: MockTextChannel()) +        self.cog = silence.Silence(self.bot) +        self.cog._init_task = asyncio.Future() +        self.cog._init_task.set_result(None) + +        overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True) +        self.cog.previous_overwrites = overwrites_cache + +        asyncio.run(self.cog._async_init())  # Populate instance attributes. + +        self.cog.scheduler.__contains__.return_value = True +        overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}' +        self.channel = MockTextChannel() +        self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False) +        self.channel.overwrites_for.return_value = self.overwrite + +    async def test_sent_correct_message(self): +        """Appropriate failure/success message was sent by the command.""" +        unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True) +        test_cases = ( +            (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)), +            (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)), +        ) +        for was_unsilenced, message, overwrite in test_cases: +            ctx = MockContext() +            with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite): +                with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced): +                    ctx.channel.overwrites_for.return_value = overwrite +                    await self.cog.unsilence.callback(self.cog, ctx) +                    ctx.channel.send.assert_called_once_with(message) + +    async def test_skipped_already_unsilenced(self): +        """Permissions were not set and `False` was returned for an already unsilenced channel.""" +        self.cog.scheduler.__contains__.return_value = False +        self.cog.previous_overwrites.get.return_value = None          channel = MockTextChannel() -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._silence(channel, False, None) -        muted_channels.add.assert_called_once_with(channel) -    async def test_unsilence_private_for_false(self): -        """Permissions are not set and `False` is returned in an unsilenced channel.""" -        channel = Mock()          self.assertFalse(await self.cog._unsilence(channel))          channel.set_permissions.assert_not_called() -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_unsilenced_channel(self, _): -        """Channel had `send_message` permissions restored""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        self.assertTrue(await self.cog._unsilence(channel)) -        channel.set_permissions.assert_called_once() -        self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages']) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_notifier(self, notifier): -        """Channel was removed from `notifier` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        await self.cog._unsilence(channel) -        notifier.remove_channel.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_removed_muted_channel(self, _): -        """Channel was removed from `muted_channels` on unsilence.""" -        perm_overwrite = MagicMock(send_messages=False) -        channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite)) -        with mock.patch.object(self.cog, "muted_channels") as muted_channels: -            await self.cog._unsilence(channel) -        muted_channels.discard.assert_called_once_with(channel) - -    @mock.patch.object(Silence, "notifier", create=True) -    async def test_unsilence_private_preserves_permissions(self, _): -        """Previous permissions were preserved when channel was unsilenced.""" -        channel = MockTextChannel() -        # Set up mock channel permission state. -        mock_permissions = PermissionOverwrite(send_messages=False) -        mock_permissions_dict = dict(mock_permissions) -        channel.overwrites_for.return_value = mock_permissions -        await self.cog._unsilence(channel) -        new_permissions = channel.set_permissions.call_args.kwargs -        # Remove 'send_messages' key because it got changed in the method. -        del new_permissions['send_messages'] -        del mock_permissions_dict['send_messages'] -        self.assertDictEqual(mock_permissions_dict, new_permissions) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    @mock.patch.object(Silence, "_mod_alerts_channel", create=True) -    def test_cog_unload_starts_task(self, alert_channel, asyncio_mock): -        """Task for sending an alert was created with present `muted_channels`.""" -        with mock.patch.object(self.cog, "muted_channels"): -            self.cog.cog_unload() -            alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ") -            asyncio_mock.create_task.assert_called_once_with(alert_channel.send()) - -    @mock.patch("bot.exts.moderation.silence.asyncio") -    def test_cog_unload_skips_task_start(self, asyncio_mock): -        """No task created with no channels.""" -        self.cog.cog_unload() -        asyncio_mock.create_task.assert_not_called() +    async def test_restored_overwrites(self): +        """Channel's `send_message` and `add_reactions` overwrites were restored.""" +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) -    @mock.patch("discord.ext.commands.has_any_role") -    @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3)) -    async def test_cog_check(self, role_check): -        """Role check is called with `MODERATION_ROLES`""" -        role_check.return_value.predicate = mock.AsyncMock() -        await self.cog.cog_check(self.ctx) -        role_check.assert_called_once_with(*(1, 2, 3)) -        role_check.return_value.predicate.assert_awaited_once_with(self.ctx) +        # Recall that these values are determined by the fixture. +        self.assertTrue(self.overwrite.send_messages) +        self.assertFalse(self.overwrite.add_reactions) + +    async def test_cache_miss_used_default_overwrites(self): +        """Both overwrites were set to None due previous values not being found in the cache.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.channel.set_permissions.assert_awaited_once_with( +            self.cog._verified_role, +            overwrite=self.overwrite, +        ) + +        self.assertIsNone(self.overwrite.send_messages) +        self.assertIsNone(self.overwrite.add_reactions) + +    async def test_cache_miss_sent_mod_alert(self): +        """A message was sent to the mod alerts channel.""" +        self.cog.previous_overwrites.get.return_value = None + +        await self.cog._unsilence(self.channel) +        self.cog._mod_alerts_channel.send.assert_awaited_once() + +    async def test_removed_notifier(self): +        """Channel was removed from `notifier`.""" +        await self.cog._unsilence(self.channel) +        self.cog.notifier.remove_channel.assert_called_once_with(self.channel) + +    async def test_deleted_cached_overwrite(self): +        """Channel was deleted from the overwrites cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id) + +    async def test_deleted_cached_time(self): +        """Channel was deleted from the timestamp cache.""" +        await self.cog._unsilence(self.channel) +        self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id) + +    async def test_cancelled_task(self): +        """The scheduled unsilence task should be cancelled.""" +        await self.cog._unsilence(self.channel) +        self.cog.scheduler.cancel.assert_called_once_with(self.channel.id) + +    async def test_preserved_other_overwrites(self): +        """Channel's other unrelated overwrites were not changed, including cache misses.""" +        for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None): +            with self.subTest(overwrite_json=overwrite_json): +                self.cog.previous_overwrites.get.return_value = overwrite_json + +                prev_overwrite_dict = dict(self.overwrite) +                await self.cog._unsilence(self.channel) +                new_overwrite_dict = dict(self.overwrite) + +                # Remove these keys because they were modified by the unsilence. +                del prev_overwrite_dict['send_messages'] +                del prev_overwrite_dict['add_reactions'] +                del new_overwrite_dict['send_messages'] +                del new_overwrite_dict['add_reactions'] + +                self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 6601fad2c..9a42d0610 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -52,6 +52,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):              ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),              ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),              ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), +            ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'), +            ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', +             'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'), +            ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', +             'print("How\'s it going?")', 'code block preceded by inline code'), +            ('`print("Hello world!")`\ntext\n`print("Hello world!")`', +             'print("Hello world!")', 'one inline code block of two')          )          for case, expected, testname in cases:              with self.subTest(msg=f'Extract code from {testname}.'): diff --git a/tests/helpers.py b/tests/helpers.py index e47fdf28f..870f66197 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,7 +5,7 @@ import itertools  import logging  import unittest.mock  from asyncio import AbstractEventLoop -from typing import Callable, Iterable, Optional +from typing import Iterable, Optional  import discord  from aiohttp import ClientSession @@ -14,6 +14,7 @@ from discord.ext.commands import Context  from bot.api import APIClient  from bot.async_stats import AsyncStatsClient  from bot.bot import Bot +from tests._autospec import autospec  # noqa: F401 other modules import it via this module  for logger in logging.Logger.manager.loggerDict.values(): @@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():      logger.setLevel(logging.CRITICAL) -def autospec(target, *attributes: str, **kwargs) -> Callable: -    """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.""" -    # Caller's kwargs should take priority and overwrite the defaults. -    kwargs = {'spec_set': True, 'autospec': True, **kwargs} - -    # Import the target if it's a string. -    # This is to support both object and string targets like patch.multiple. -    if type(target) is str: -        target = unittest.mock._importer(target) - -    def decorator(func): -        for attribute in attributes: -            patcher = unittest.mock.patch.object(target, attribute, **kwargs) -            func = patcher(func) -        return func -    return decorator - -  class HashableMixin(discord.mixins.EqualityComparable):      """      Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. | 
