aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--LICENSE-THIRD-PARTY72
-rw-r--r--README.md2
-rw-r--r--bot/constants.py6
-rw-r--r--bot/exts/info/information.py49
-rw-r--r--bot/exts/moderation/infraction/_scheduler.py9
-rw-r--r--bot/exts/moderation/infraction/management.py10
-rw-r--r--bot/exts/moderation/silence.py202
-rw-r--r--bot/exts/moderation/voice_gate.py18
-rw-r--r--bot/exts/utils/reminders.py8
-rw-r--r--bot/exts/utils/snekbox.py45
-rw-r--r--bot/utils/channel.py16
-rw-r--r--config-default.yml13
-rw-r--r--tests/_autospec.py64
-rw-r--r--tests/bot/exts/moderation/test_silence.py587
-rw-r--r--tests/bot/exts/utils/test_snekbox.py7
-rw-r--r--tests/helpers.py21
16 files changed, 799 insertions, 330 deletions
diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY
index 3349d7c05..eacd9b952 100644
--- a/LICENSE-THIRD-PARTY
+++ b/LICENSE-THIRD-PARTY
@@ -1,14 +1,13 @@
-BSD 3-Clause License
-
+---------------------------------------------------------------------------------------------------
+ BSD 3-Clause License
Applies to:
-- _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL in bot/cogs/codeblock/parsing.py
-
-- Copyright (c) 2008-Present, IPython Development Team
-- Copyright (c) 2001-2007, Fernando Perez <[email protected]>
-- Copyright (c) 2001, Janko Hauser <[email protected]>
-- Copyright (c) 2001, Nathaniel Gray <[email protected]>
-
-All rights reserved.
+ - Copyright (c) 2008-Present, IPython Development Team
+ Copyright (c) 2001-2007, Fernando Perez <[email protected]>
+ Copyright (c) 2001, Janko Hauser <[email protected]>
+ Copyright (c) 2001, Nathaniel Gray <[email protected]>
+ All rights reserved.
+ - bot/exts/info/codeblock/_parsing.py: _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL
+---------------------------------------------------------------------------------------------------
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
@@ -34,3 +33,56 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+---------------------------------------------------------------------------------------------------
+ PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
+Applies to:
+ - Copyright © 2001-2020 Python Software Foundation. All rights reserved.
+ - tests/_autospec.py: _decoration_helper
+---------------------------------------------------------------------------------------------------
+
+1. This LICENSE AGREEMENT is between the Python Software Foundation
+("PSF"), and the Individual or Organization ("Licensee") accessing and
+otherwise using this software ("Python") in source or binary form and
+its associated documentation.
+
+2. Subject to the terms and conditions of this License Agreement, PSF hereby
+grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
+analyze, test, perform and/or display publicly, prepare derivative works,
+distribute, and otherwise use Python alone or in any derivative version,
+provided, however, that PSF's License Agreement and PSF's notice of copyright,
+i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation;
+All Rights Reserved" are retained in Python alone or in any derivative version
+prepared by Licensee.
+
+3. In the event Licensee prepares a derivative work that is based on
+or incorporates Python or any part thereof, and wants to make
+the derivative work available to others as provided herein, then
+Licensee hereby agrees to include in any such work a brief summary of
+the changes made to Python.
+
+4. PSF is making Python available to Licensee on an "AS IS"
+basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+INFRINGE ANY THIRD PARTY RIGHTS.
+
+5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+6. This License Agreement will automatically terminate upon a material
+breach of its terms and conditions.
+
+7. Nothing in this License Agreement shall be deemed to create any
+relationship of agency, partnership, or joint venture between PSF and
+Licensee. This License Agreement does not grant permission to use PSF
+trademarks or trade name in a trademark sense to endorse or promote
+products or services of Licensee, or any third party.
+
+8. By copying, installing or otherwise using Python, Licensee
+agrees to be bound by the terms and conditions of this License
+Agreement.
diff --git a/README.md b/README.md
index cae7c3454..b37ece296 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Python Utility Bot
-[![Discord](https://img.shields.io/static/v1?label=Python%20Discord&logo=discord&message=%3E60k%20members&color=%237289DA&logoColor=white)](https://discord.gg/2B963hn)
+[![Discord](https://img.shields.io/static/v1?label=Python%20Discord&logo=discord&message=%3E100k%20members&color=%237289DA&logoColor=white)](https://discord.gg/2B963hn)
[![Build Status](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)](https://dev.azure.com/python-discord/Python%20Discord/_build/latest?definitionId=1&branchName=master)
[![Tests](https://img.shields.io/azure-devops/tests/python-discord/Python%20Discord/1?compact_message)](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)
[![Coverage](https://img.shields.io/azure-devops/coverage/python-discord/Python%20Discord/1/master)](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)
diff --git a/bot/constants.py b/bot/constants.py
index c79de012c..23d5b4304 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -377,6 +377,7 @@ class Categories(metaclass=YAMLGetter):
help_in_use: int
help_dormant: int
modmail: int
+ voice: int
class Channels(metaclass=YAMLGetter):
@@ -425,6 +426,7 @@ class Channels(metaclass=YAMLGetter):
user_event_announcements: int
user_log: int
verification: int
+ voice_chat: int
voice_gate: int
voice_log: int
@@ -471,6 +473,7 @@ class Guild(metaclass=YAMLGetter):
id: int
invite: str # Discord invite, gets embedded in chat
moderation_channels: List[int]
+ moderation_categories: List[int]
moderation_roles: List[int]
modlog_blacklist: List[int]
reminder_whitelist: List[int]
@@ -639,6 +642,9 @@ STAFF_ROLES = Guild.staff_roles
# Channel combinations
MODERATION_CHANNELS = Guild.moderation_channels
+# Category combinations
+MODERATION_CATEGORIES = Guild.moderation_categories
+
# Bot replies
NEGATIVE_REPLIES = [
"Noooooo!!",
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 0f50138e7..5aaf85e5a 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -6,14 +6,16 @@ from collections import Counter, defaultdict
from string import Template
from typing import Any, Mapping, Optional, Tuple, Union
-from discord import ChannelType, Colour, Embed, Guild, Member, Message, Role, Status, utils
+from discord import ChannelType, Colour, Embed, Guild, Message, Role, Status, utils
from discord.abc import GuildChannel
from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role
from bot import constants
from bot.bot import Bot
+from bot.converters import FetchedMember
from bot.decorators import in_whitelist
from bot.pagination import LinePaginator
+from bot.utils.channel import is_mod_channel
from bot.utils.checks import cooldown_with_role_bypass, has_no_roles_check, in_whitelist_check
from bot.utils.time import time_since
@@ -191,7 +193,7 @@ class Information(Cog):
await ctx.send(embed=embed)
@command(name="user", aliases=["user_info", "member", "member_info"])
- async def user_info(self, ctx: Context, user: Member = None) -> None:
+ async def user_info(self, ctx: Context, user: FetchedMember = None) -> None:
"""Returns info about a user."""
if user is None:
user = ctx.author
@@ -206,12 +208,14 @@ class Information(Cog):
embed = await self.create_user_embed(ctx, user)
await ctx.send(embed=embed)
- async def create_user_embed(self, ctx: Context, user: Member) -> Embed:
+ async def create_user_embed(self, ctx: Context, user: FetchedMember) -> Embed:
"""Creates an embed containing information on the `user`."""
+ on_server = bool(ctx.guild.get_member(user.id))
+
created = time_since(user.created_at, max_units=3)
name = str(user)
- if user.nick:
+ if on_server and user.nick:
name = f"{user.nick} ({name})"
badges = []
@@ -220,8 +224,16 @@ class Information(Cog):
if is_set and (emoji := getattr(constants.Emojis, f"badge_{badge}", None)):
badges.append(emoji)
- joined = time_since(user.joined_at, max_units=3)
- roles = ", ".join(role.mention for role in user.roles[1:])
+ if on_server:
+ joined = time_since(user.joined_at, max_units=3)
+ roles = ", ".join(role.mention for role in user.roles[1:])
+ membership = textwrap.dedent(f"""
+ Joined: {joined}
+ Roles: {roles or None}
+ """).strip()
+ else:
+ roles = None
+ membership = "The user is not a member of the server"
fields = [
(
@@ -234,21 +246,12 @@ class Information(Cog):
),
(
"Member information",
- textwrap.dedent(f"""
- Joined: {joined}
- Roles: {roles or None}
- """).strip()
+ membership
),
]
- # Use getattr to future-proof for commands invoked via DMs.
- show_verbose = (
- ctx.channel.id in constants.MODERATION_CHANNELS
- or getattr(ctx.channel, "category_id", None) == constants.Categories.modmail
- )
-
# Show more verbose output in moderation channels for infractions and nominations
- if show_verbose:
+ if is_mod_channel(ctx.channel):
fields.append(await self.expanded_user_infraction_counts(user))
fields.append(await self.user_nomination_counts(user))
else:
@@ -268,13 +271,13 @@ class Information(Cog):
return embed
- async def basic_user_infraction_counts(self, member: Member) -> Tuple[str, str]:
+ async def basic_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:
"""Gets the total and active infraction counts for the given `member`."""
infractions = await self.bot.api_client.get(
'bot/infractions',
params={
'hidden': 'False',
- 'user__id': str(member.id)
+ 'user__id': str(user.id)
}
)
@@ -285,7 +288,7 @@ class Information(Cog):
return "Infractions", infraction_output
- async def expanded_user_infraction_counts(self, member: Member) -> Tuple[str, str]:
+ async def expanded_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:
"""
Gets expanded infraction counts for the given `member`.
@@ -295,7 +298,7 @@ class Information(Cog):
infractions = await self.bot.api_client.get(
'bot/infractions',
params={
- 'user__id': str(member.id)
+ 'user__id': str(user.id)
}
)
@@ -326,12 +329,12 @@ class Information(Cog):
return "Infractions", "\n".join(infraction_output)
- async def user_nomination_counts(self, member: Member) -> Tuple[str, str]:
+ async def user_nomination_counts(self, user: FetchedMember) -> Tuple[str, str]:
"""Gets the active and historical nomination counts for the given `member`."""
nominations = await self.bot.api_client.get(
'bot/nominations',
params={
- 'user__id': str(member.id)
+ 'user__id': str(user.id)
}
)
diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py
index bba80afaf..bebade0ae 100644
--- a/bot/exts/moderation/infraction/_scheduler.py
+++ b/bot/exts/moderation/infraction/_scheduler.py
@@ -12,11 +12,12 @@ from discord.ext.commands import Context
from bot import constants
from bot.api import ResponseCodeError
from bot.bot import Bot
-from bot.constants import Colours, MODERATION_CHANNELS
+from bot.constants import Colours
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._utils import UserSnowflake
from bot.exts.moderation.modlog import ModLog
from bot.utils import messages, scheduling, time
+from bot.utils.channel import is_mod_channel
log = logging.getLogger(__name__)
@@ -136,11 +137,7 @@ class InfractionScheduler:
)
if reason:
end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})"
- elif ctx.channel.id not in MODERATION_CHANNELS:
- log.trace(
- f"Infraction #{id_} context is not in a mod channel; omitting infraction count and id."
- )
- else:
+ elif is_mod_channel(ctx.channel):
log.trace(f"Fetching total infraction count for {user}.")
infractions = await self.bot.api_client.get(
diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py
index cdab1a6c7..394f63da3 100644
--- a/bot/exts/moderation/infraction/management.py
+++ b/bot/exts/moderation/infraction/management.py
@@ -15,7 +15,7 @@ from bot.exts.moderation.infraction.infractions import Infractions
from bot.exts.moderation.modlog import ModLog
from bot.pagination import LinePaginator
from bot.utils import messages, time
-from bot.utils.checks import in_whitelist_check
+from bot.utils.channel import is_mod_channel
log = logging.getLogger(__name__)
@@ -295,13 +295,7 @@ class ModManagement(commands.Cog):
"""Only allow moderators inside moderator channels to invoke the commands in this cog."""
checks = [
await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx),
- in_whitelist_check(
- ctx,
- channels=constants.MODERATION_CHANNELS,
- categories=[constants.Categories.modmail],
- redirect=None,
- fail_silently=True,
- )
+ is_mod_channel(ctx.channel)
]
return all(checks)
diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py
index ac0c1c85e..e6712b3b6 100644
--- a/bot/exts/moderation/silence.py
+++ b/bot/exts/moderation/silence.py
@@ -1,8 +1,11 @@
-import asyncio
+import json
import logging
from contextlib import suppress
+from datetime import datetime, timedelta, timezone
+from operator import attrgetter
from typing import Optional
+from async_rediscache import RedisCache
from discord import TextChannel
from discord.ext import commands, tasks
from discord.ext.commands import Context
@@ -10,10 +13,25 @@ from discord.ext.commands import Context
from bot.bot import Bot
from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles
from bot.converters import HushDurationConverter
+from bot.utils.lock import LockedResourceError, lock_arg
from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
+LOCK_NAMESPACE = "silence"
+
+MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced."
+MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely."
+MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)."
+
+MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced."
+MSG_UNSILENCE_MANUAL = (
+ f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were "
+ f"set manually or the cache was prematurely cleared. "
+ f"Please edit the overwrites manually to unsilence."
+)
+MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel."
+
class SilenceNotifier(tasks.Loop):
"""Loop notifier for posting notices to `alert_channel` containing added channels."""
@@ -56,25 +74,32 @@ class SilenceNotifier(tasks.Loop):
class Silence(commands.Cog):
"""Commands for stopping channel messages for `verified` role in a channel."""
+ # Maps muted channel IDs to their previous overwrites for send_message and add_reactions.
+ # Overwrites are stored as JSON.
+ previous_overwrites = RedisCache()
+
+ # Maps muted channel IDs to POSIX timestamps of when they'll be unsilenced.
+ # A timestamp equal to -1 means it's indefinite.
+ unsilence_timestamps = RedisCache()
+
def __init__(self, bot: Bot):
self.bot = bot
self.scheduler = Scheduler(self.__class__.__name__)
- self.muted_channels = set()
- self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars())
- self._get_instance_vars_event = asyncio.Event()
+ self._init_task = self.bot.loop.create_task(self._async_init())
- async def _get_instance_vars(self) -> None:
- """Get instance variables after they're available to get from the guild."""
+ async def _async_init(self) -> None:
+ """Set instance attributes once the guild is available and reschedule unsilences."""
await self.bot.wait_until_guild_available()
+
guild = self.bot.get_guild(Guild.id)
self._verified_role = guild.get_role(Roles.verified)
self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts)
- self._mod_log_channel = self.bot.get_channel(Channels.mod_log)
- self.notifier = SilenceNotifier(self._mod_log_channel)
- self._get_instance_vars_event.set()
+ self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log))
+ await self._reschedule()
@commands.command(aliases=("hush",))
+ @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True)
async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None:
"""
Silence the current channel for `duration` minutes or `forever`.
@@ -82,18 +107,25 @@ class Silence(commands.Cog):
Duration is capped at 15 minutes, passing forever makes the silence indefinite.
Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start.
"""
- await self._get_instance_vars_event.wait()
- log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.")
- if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration):
- await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.")
- return
- if duration is None:
- await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.")
+ await self._init_task
+
+ channel_info = f"#{ctx.channel} ({ctx.channel.id})"
+ log.debug(f"{ctx.author} is silencing channel {channel_info}.")
+
+ if not await self._set_silence_overwrites(ctx.channel):
+ log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.")
+ await ctx.send(MSG_SILENCE_FAIL)
return
- await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).")
+ await self._schedule_unsilence(ctx, duration)
- self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence))
+ if duration is None:
+ self.notifier.add_channel(ctx.channel)
+ log.info(f"Silenced {channel_info} indefinitely.")
+ await ctx.send(MSG_SILENCE_PERMANENT)
+ else:
+ log.info(f"Silenced {channel_info} for {duration} minute(s).")
+ await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration))
@commands.command(aliases=("unhush",))
async def unsilence(self, ctx: Context) -> None:
@@ -102,61 +134,115 @@ class Silence(commands.Cog):
If the channel was silenced indefinitely, notifications for the channel will stop.
"""
- await self._get_instance_vars_event.wait()
+ await self._init_task
log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.")
- if not await self._unsilence(ctx.channel):
- await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.")
+ await self._unsilence_wrapper(ctx.channel)
+
+ @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True)
+ async def _unsilence_wrapper(self, channel: TextChannel) -> None:
+ """Unsilence `channel` and send a success/failure message."""
+ if not await self._unsilence(channel):
+ overwrite = channel.overwrites_for(self._verified_role)
+ if overwrite.send_messages is False or overwrite.add_reactions is False:
+ await channel.send(MSG_UNSILENCE_MANUAL)
+ else:
+ await channel.send(MSG_UNSILENCE_FAIL)
else:
- await ctx.send(f"{Emojis.check_mark} unsilenced current channel.")
+ await channel.send(MSG_UNSILENCE_SUCCESS)
- async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool:
- """
- Silence `channel` for `self._verified_role`.
+ async def _set_silence_overwrites(self, channel: TextChannel) -> bool:
+ """Set silence permission overwrites for `channel` and return True if successful."""
+ overwrite = channel.overwrites_for(self._verified_role)
+ prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions)
- If `persistent` is `True` add `channel` to notifier.
- `duration` is only used for logging; if None is passed `persistent` should be True to not log None.
- Return `True` if channel permissions were changed, `False` otherwise.
- """
- current_overwrite = channel.overwrites_for(self._verified_role)
- if current_overwrite.send_messages is False:
- log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.")
+ if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()):
return False
- await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False))
- self.muted_channels.add(channel)
- if persistent:
- log.info(f"Silenced #{channel} ({channel.id}) indefinitely.")
- self.notifier.add_channel(channel)
- return True
-
- log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).")
+
+ overwrite.update(send_messages=False, add_reactions=False)
+ await channel.set_permissions(self._verified_role, overwrite=overwrite)
+ await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites))
+
return True
+ async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None:
+ """Schedule `ctx.channel` to be unsilenced if `duration` is not None."""
+ if duration is None:
+ await self.unsilence_timestamps.set(ctx.channel.id, -1)
+ else:
+ self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence))
+ unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration)
+ await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp())
+
async def _unsilence(self, channel: TextChannel) -> bool:
"""
Unsilence `channel`.
- Check if `channel` is silenced through a `PermissionOverwrite`,
- if it is unsilence it and remove it from the notifier.
+ If `channel` has a silence task scheduled or has its previous overwrites cached, unsilence
+ it, cancel the task, and remove it from the notifier. Notify admins if it has a task but
+ not cached overwrites.
+
Return `True` if channel permissions were changed, `False` otherwise.
"""
- current_overwrite = channel.overwrites_for(self._verified_role)
- if current_overwrite.send_messages is False:
- await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None))
- log.info(f"Unsilenced channel #{channel} ({channel.id}).")
- self.scheduler.cancel(channel.id)
- self.notifier.remove_channel(channel)
- self.muted_channels.discard(channel)
- return True
- log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.")
- return False
+ prev_overwrites = await self.previous_overwrites.get(channel.id)
+ if channel.id not in self.scheduler and prev_overwrites is None:
+ log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.")
+ return False
+
+ overwrite = channel.overwrites_for(self._verified_role)
+ if prev_overwrites is None:
+ log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.")
+ overwrite.update(send_messages=None, add_reactions=None)
+ else:
+ overwrite.update(**json.loads(prev_overwrites))
+
+ await channel.set_permissions(self._verified_role, overwrite=overwrite)
+ log.info(f"Unsilenced channel #{channel} ({channel.id}).")
+
+ self.scheduler.cancel(channel.id)
+ self.notifier.remove_channel(channel)
+ await self.previous_overwrites.delete(channel.id)
+ await self.unsilence_timestamps.delete(channel.id)
+
+ if prev_overwrites is None:
+ await self._mod_alerts_channel.send(
+ f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing "
+ f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` "
+ f"overwrites for {self._verified_role.mention} are at their desired values."
+ )
+
+ return True
+
+ async def _reschedule(self) -> None:
+ """Reschedule unsilencing of active silences and add permanent ones to the notifier."""
+ for channel_id, timestamp in await self.unsilence_timestamps.items():
+ channel = self.bot.get_channel(channel_id)
+ if channel is None:
+ log.info(f"Can't reschedule silence for {channel_id}: channel not found.")
+ continue
+
+ if timestamp == -1:
+ log.info(f"Adding permanent silence for #{channel} ({channel.id}) to the notifier.")
+ self.notifier.add_channel(channel)
+ continue
+
+ dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+ delta = (dt - datetime.now(tz=timezone.utc)).total_seconds()
+ if delta <= 0:
+ # Suppress the error since it's not being invoked by a user via the command.
+ with suppress(LockedResourceError):
+ await self._unsilence_wrapper(channel)
+ else:
+ log.info(f"Rescheduling silence for #{channel} ({channel.id}).")
+ self.scheduler.schedule_later(delta, channel_id, self._unsilence_wrapper(channel))
def cog_unload(self) -> None:
- """Send alert with silenced channels and cancel scheduled tasks on unload."""
- self.scheduler.cancel_all()
- if self.muted_channels:
- channels_string = ''.join(channel.mention for channel in self.muted_channels)
- message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}"
- asyncio.create_task(self._mod_alerts_channel.send(message))
+ """Cancel the init task and scheduled tasks."""
+ # It's important to wait for _init_task (specifically for _reschedule) to be cancelled
+ # before cancelling scheduled tasks. Otherwise, it's possible for _reschedule to schedule
+ # more tasks after cancel_all has finished, despite _init_task.cancel being called first.
+ # This is cause cancel() on its own doesn't block until the task is cancelled.
+ self._init_task.cancel()
+ self._init_task.add_done_callback(lambda _: self.scheduler.cancel_all())
# This cannot be static (must have a __func__ attribute).
async def cog_check(self, ctx: Context) -> bool:
diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py
index 8b68b8e2d..c2743e136 100644
--- a/bot/exts/moderation/voice_gate.py
+++ b/bot/exts/moderation/voice_gate.py
@@ -1,3 +1,4 @@
+import asyncio
import logging
from contextlib import suppress
from datetime import datetime, timedelta
@@ -21,7 +22,7 @@ FAILED_MESSAGE = (
)
MESSAGE_FIELD_MAP = {
- "verified_at": f"have been verified for less {GateConf.minimum_days_verified} days",
+ "verified_at": f"have been verified for less than {GateConf.minimum_days_verified} days",
"voice_banned": "have an active voice ban infraction",
"total_messages": f"have sent less than {GateConf.minimum_messages} messages",
}
@@ -50,9 +51,6 @@ class VoiceGate(Cog):
- You must have accepted our rules over a certain number of days ago
- You must not be actively banned from using our voice channels
"""
- # Send this as first thing in order to return after sending DM
- await ctx.send(f"{ctx.author.mention}, check your DMs.")
-
try:
data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data")
except ResponseCodeError as e:
@@ -104,21 +102,31 @@ class VoiceGate(Cog):
)
try:
await ctx.author.send(embed=embed)
+ await ctx.send(f"{ctx.author}, please check your DMs.")
except discord.Forbidden:
await ctx.channel.send(ctx.author.mention, embed=embed)
return
self.mod_log.ignore(Event.member_update, ctx.author.id)
- await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed")
embed = discord.Embed(
title="Voice gate passed",
description="You have been granted permission to use voice channels in Python Discord.",
color=Colour.green()
)
+
+ if ctx.author.voice:
+ embed.description += "\n\nPlease reconnect to your voice channel to be granted your new permissions."
+
try:
await ctx.author.send(embed=embed)
+ await ctx.send(f"{ctx.author}, please check your DMs.")
except discord.Forbidden:
await ctx.channel.send(ctx.author.mention, embed=embed)
+
+ # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it.
+ await asyncio.sleep(3)
+ await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed")
+
self.bot.stats.incr("voice_gate.passed")
@Cog.listener()
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index bf4e24661..3113a1149 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -23,7 +23,7 @@ from bot.utils.time import humanize_delta
log = logging.getLogger(__name__)
-NAMESPACE = "reminder" # Used for the mutually_exclusive decorator; constant to prevent typos
+LOCK_NAMESPACE = "reminder"
WHITELISTED_CHANNELS = Guild.reminder_whitelist
MAXIMUM_REMINDERS = 5
@@ -170,7 +170,7 @@ class Reminders(Cog):
log.trace(f"Scheduling new task #{reminder['id']}")
self.schedule_reminder(reminder)
- @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True)
+ @lock_arg(LOCK_NAMESPACE, "reminder", itemgetter("id"), raise_error=True)
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
is_valid, user, channel = self.ensure_valid_reminder(reminder)
@@ -378,7 +378,7 @@ class Reminders(Cog):
mention_ids = [mention.id for mention in mentions]
await self.edit_reminder(ctx, id_, {"mentions": mention_ids})
- @lock_arg(NAMESPACE, "id_", raise_error=True)
+ @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)
async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None:
"""Edits a reminder with the given payload, then sends a confirmation message."""
if not await self._can_modify(ctx, id_):
@@ -398,7 +398,7 @@ class Reminders(Cog):
await self._reschedule_reminder(reminder)
@remind_group.command("delete", aliases=("remove", "cancel"))
- @lock_arg(NAMESPACE, "id_", raise_error=True)
+ @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)
async def delete_reminder(self, ctx: Context, id_: int) -> None:
"""Delete one of your active reminders."""
if not await self._can_modify(ctx, id_):
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index cad451571..41cb00541 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -21,14 +21,12 @@ log = logging.getLogger(__name__)
ESCAPE_REGEX = re.compile("[`\u202E\u200B]{3,}")
FORMATTED_CODE_REGEX = re.compile(
- r"^\s*" # any leading whitespace from the beginning of the string
r"(?P<delim>(?P<block>```)|``?)" # code delimiter: 1-3 backticks; (?P=block) only matches if it's a block
r"(?(block)(?:(?P<lang>[a-z]+)\n)?)" # if we're in a block, match optional language (only letters plus newline)
r"(?:[ \t]*\n)*" # any blank (empty or tabs/spaces only) lines before the code
r"(?P<code>.*?)" # extract all code inside the markup
r"\s*" # any more whitespace before the end of the code markup
- r"(?P=delim)" # match the exact same delimiter from the start again
- r"\s*$", # any trailing whitespace until the end of the string
+ r"(?P=delim)", # match the exact same delimiter from the start again
re.DOTALL | re.IGNORECASE # "." also matches newlines, case insensitive
)
RAW_CODE_REGEX = re.compile(
@@ -41,8 +39,8 @@ RAW_CODE_REGEX = re.compile(
MAX_PASTE_LEN = 10000
# `!eval` command whitelists
-EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric, Channels.code_help_voice, Channels.code_help_voice_2)
-EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use)
+EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric)
+EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use, Categories.voice)
EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
SIGKILL = 9
@@ -76,23 +74,32 @@ class Snekbox(Cog):
@staticmethod
def prepare_input(code: str) -> str:
- """Extract code from the Markdown, format it, and insert it into the code template."""
- match = FORMATTED_CODE_REGEX.fullmatch(code)
- if match:
- code, block, lang, delim = match.group("code", "block", "lang", "delim")
- code = textwrap.dedent(code)
- if block:
- info = (f"'{lang}' highlighted" if lang else "plain") + " code block"
+ """
+ Extract code from the Markdown, format it, and insert it into the code template.
+
+ If there is any code block, ignore text outside the code block.
+ Use the first code block, but prefer a fenced code block.
+ If there are several fenced code blocks, concatenate only the fenced code blocks.
+ """
+ if match := list(FORMATTED_CODE_REGEX.finditer(code)):
+ blocks = [block for block in match if block.group("block")]
+
+ if len(blocks) > 1:
+ code = '\n'.join(block.group("code") for block in blocks)
+ info = "several code blocks"
else:
- info = f"{delim}-enclosed inline code"
- log.trace(f"Extracted {info} for evaluation:\n{code}")
+ match = match[0] if len(blocks) == 0 else blocks[0]
+ code, block, lang, delim = match.group("code", "block", "lang", "delim")
+ if block:
+ info = (f"'{lang}' highlighted" if lang else "plain") + " code block"
+ else:
+ info = f"{delim}-enclosed inline code"
else:
- code = textwrap.dedent(RAW_CODE_REGEX.fullmatch(code).group("code"))
- log.trace(
- f"Eval message contains unformatted or badly formatted code, "
- f"stripping whitespace only:\n{code}"
- )
+ code = RAW_CODE_REGEX.fullmatch(code).group("code")
+ info = "unformatted or badly formatted code"
+ code = textwrap.dedent(code)
+ log.trace(f"Extracted {info} for evaluation:\n{code}")
return code
@staticmethod
diff --git a/bot/utils/channel.py b/bot/utils/channel.py
index 851f9e1fe..6bf70bfde 100644
--- a/bot/utils/channel.py
+++ b/bot/utils/channel.py
@@ -2,6 +2,7 @@ import logging
import discord
+from bot import constants
from bot.constants import Categories
log = logging.getLogger(__name__)
@@ -15,6 +16,21 @@ def is_help_channel(channel: discord.TextChannel) -> bool:
return any(is_in_category(channel, category) for category in categories)
+def is_mod_channel(channel: discord.TextChannel) -> bool:
+ """True if `channel` is considered a mod channel."""
+ if channel.id in constants.MODERATION_CHANNELS:
+ log.trace(f"Channel #{channel} is a configured mod channel")
+ return True
+
+ elif any(is_in_category(channel, category) for category in constants.MODERATION_CATEGORIES):
+ log.trace(f"Channel #{channel} is in a configured mod category")
+ return True
+
+ else:
+ log.trace(f"Channel #{channel} is not a mod channel")
+ return False
+
+
def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
"""Return True if `channel` is within a category with `category_id`."""
return getattr(channel, "category_id", None) == category_id
diff --git a/config-default.yml b/config-default.yml
index c712d1eb7..071f6e1ec 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -128,7 +128,9 @@ guild:
help_available: 691405807388196926
help_in_use: 696958401460043776
help_dormant: 691405908919451718
- modmail: 714494672835444826
+ modmail: &MODMAIL 714494672835444826
+ logs: &LOGS 468520609152892958
+ voice: 356013253765234688
channels:
# Public announcement and news channels
@@ -180,7 +182,7 @@ guild:
incidents: 714214212200562749
incidents_archive: 720668923636351037
mods: &MODS 305126844661760000
- mod_alerts: &MOD_ALERTS 473092532147060736
+ mod_alerts: 473092532147060736
mod_spam: &MOD_SPAM 620607373828030464
organisation: &ORGANISATION 551789653284356126
staff_lounge: &STAFF_LOUNGE 464905259261755392
@@ -194,6 +196,7 @@ guild:
# Voice
code_help_voice: 755154969761677312
code_help_voice_2: 766330079135268884
+ voice_chat: 412357430186344448
admins_voice: &ADMINS_VOICE 500734494840717332
staff_voice: &STAFF_VOICE 412375055910043655
@@ -201,10 +204,13 @@ guild:
big_brother_logs: &BB_LOGS 468507907357409333
talent_pool: &TALENT_POOL 534321732593647616
+ moderation_categories:
+ - *MODMAIL
+ - *LOGS
+
moderation_channels:
- *ADMINS
- *ADMIN_SPAM
- - *MOD_ALERTS
- *MODS
- *MOD_SPAM
@@ -494,6 +500,7 @@ python_news:
- 'python-ideas'
- 'python-announce-list'
- 'pypi-announce'
+ - 'python-dev'
channel: *PYNEWS_CHANNEL
webhook: *PYNEWS_WEBHOOK
diff --git a/tests/_autospec.py b/tests/_autospec.py
new file mode 100644
index 000000000..ee2fc1973
--- /dev/null
+++ b/tests/_autospec.py
@@ -0,0 +1,64 @@
+import contextlib
+import functools
+import unittest.mock
+from typing import Callable
+
+
[email protected](unittest.mock._patch.decoration_helper)
+def _decoration_helper(self, patched, args, keywargs):
+ """Skips adding patchings as args if their `dont_pass` attribute is True."""
+ # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added.
+ extra_args = []
+ with contextlib.ExitStack() as exit_stack:
+ for patching in patched.patchings:
+ arg = exit_stack.enter_context(patching)
+ if not getattr(patching, "dont_pass", False):
+ # Only add the patching as an arg if dont_pass is False.
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is unittest.mock.DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield args, keywargs
+
+
[email protected](unittest.mock._patch.copy)
+def _copy(self):
+ """Copy the `dont_pass` attribute along with the standard copy operation."""
+ patcher_copy = _copy.original(self)
+ patcher_copy.dont_pass = getattr(self, "dont_pass", False)
+ return patcher_copy
+
+
+# Monkey-patch the patcher class :)
+_copy.original = unittest.mock._patch.copy
+unittest.mock._patch.copy = _copy
+unittest.mock._patch.decoration_helper = _decoration_helper
+
+
+def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable:
+ """
+ Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.
+
+ If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object.
+ """
+ # Caller's kwargs should take priority and overwrite the defaults.
+ kwargs = dict(spec_set=True, autospec=True)
+ kwargs.update(patch_kwargs)
+
+ # Import the target if it's a string.
+ # This is to support both object and string targets like patch.multiple.
+ if type(target) is str:
+ target = unittest.mock._importer(target)
+
+ def decorator(func):
+ for attribute in attributes:
+ patcher = unittest.mock.patch.object(target, attribute, **kwargs)
+ if not pass_mocks:
+ # A custom attribute to keep track of which patchings should be skipped.
+ patcher.dont_pass = True
+ func = patcher(func)
+ return func
+ return decorator
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index 3c2d52ae0..104293d8e 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -1,23 +1,49 @@
+import asyncio
import unittest
+from datetime import datetime, timezone
from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
+from async_rediscache import RedisSession
from discord import PermissionOverwrite
-from bot.constants import Channels, Emojis, Guild, Roles
-from bot.exts.moderation.silence import Silence, SilenceNotifier
-from tests.helpers import MockBot, MockContext, MockTextChannel
+from bot.constants import Channels, Guild, Roles
+from bot.exts.moderation import silence
+from tests.helpers import MockBot, MockContext, MockTextChannel, autospec
+
+redis_session = None
+redis_loop = asyncio.get_event_loop()
+
+
+def setUpModule(): # noqa: N802
+ """Create and connect to the fakeredis session."""
+ global redis_session
+ redis_session = RedisSession(use_fakeredis=True)
+ redis_loop.run_until_complete(redis_session.connect())
+
+
+def tearDownModule(): # noqa: N802
+ """Close the fakeredis session."""
+ if redis_session:
+ redis_loop.run_until_complete(redis_session.close())
+
+
+# Have to subclass it because builtins can't be patched.
+class PatchedDatetime(datetime):
+ """A datetime object with a mocked now() function."""
+
+ now = mock.create_autospec(datetime, "now")
class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.alert_channel = MockTextChannel()
- self.notifier = SilenceNotifier(self.alert_channel)
+ self.notifier = silence.SilenceNotifier(self.alert_channel)
self.notifier.stop = self.notifier_stop_mock = Mock()
self.notifier.start = self.notifier_start_mock = Mock()
def test_add_channel_adds_channel(self):
- """Channel in FirstHash with current loop is added to internal set."""
+ """Channel is added to `_silenced_channels` with the current loop."""
channel = Mock()
with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
self.notifier.add_channel(channel)
@@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
self.notifier_start_mock.assert_not_called()
def test_remove_channel_removes_channel(self):
- """Channel in FirstHash is removed from `_silenced_channels`."""
+ """Channel is removed from `_silenced_channels`."""
channel = Mock()
with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
self.notifier.remove_channel(channel)
@@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
with self.subTest(current_loop=current_loop):
with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
await self.notifier._notifier()
- self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ")
+ self.alert_channel.send.assert_called_once_with(
+ f"<@&{Roles.moderators}> currently silenced channels: "
+ )
self.alert_channel.send.reset_mock()
async def test_notifier_skips_alert(self):
@@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
self.alert_channel.send.assert_not_called()
-class SilenceTests(unittest.IsolatedAsyncioTestCase):
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the general functionality of the Silence cog."""
+
+ @autospec(silence, "Scheduler", pass_mocks=False)
def setUp(self) -> None:
self.bot = MockBot()
- self.cog = Silence(self.bot)
- self.ctx = MockContext()
- self.cog._verified_role = None
- # Set event so command callbacks can continue.
- self.cog._get_instance_vars_event.set()
+ self.cog = silence.Silence(self.bot)
- async def test_instance_vars_got_guild(self):
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_guild(self):
"""Bot got guild after it became available."""
- await self.cog._get_instance_vars()
- self.bot.wait_until_guild_available.assert_called_once()
+ await self.cog._async_init()
+ self.bot.wait_until_guild_available.assert_awaited_once()
self.bot.get_guild.assert_called_once_with(Guild.id)
- async def test_instance_vars_got_role(self):
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_role(self):
"""Got `Roles.verified` role from guild."""
- await self.cog._get_instance_vars()
guild = self.bot.get_guild()
- guild.get_role.assert_called_once_with(Roles.verified)
+ guild.get_role.side_effect = lambda id_: Mock(id=id_)
- async def test_instance_vars_got_channels(self):
+ await self.cog._async_init()
+ self.assertEqual(self.cog._verified_role.id, Roles.verified)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_channels(self):
"""Got channels from bot."""
- await self.cog._get_instance_vars()
- self.bot.get_channel.called_once_with(Channels.mod_alerts)
- self.bot.get_channel.called_once_with(Channels.mod_log)
+ self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
+
+ await self.cog._async_init()
+ self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)
- @mock.patch("bot.exts.moderation.silence.SilenceNotifier")
- async def test_instance_vars_got_notifier(self, notifier):
+ @autospec(silence, "SilenceNotifier")
+ async def test_async_init_got_notifier(self, notifier):
"""Notifier was started with channel."""
- mod_log = MockTextChannel()
- self.bot.get_channel.side_effect = (None, mod_log)
- await self.cog._get_instance_vars()
- notifier.assert_called_once_with(mod_log)
- self.bot.get_channel.side_effect = None
-
- async def test_silence_sent_correct_discord_message(self):
- """Check if proper message was sent when called with duration in channel with previous state."""
+ self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
+
+ await self.cog._async_init()
+ notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))
+ self.assertEqual(self.cog.notifier, notifier.return_value)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_rescheduled(self):
+ """`_reschedule_` coroutine was awaited."""
+ self.cog._reschedule = mock.create_autospec(self.cog._reschedule)
+ await self.cog._async_init()
+ self.cog._reschedule.assert_awaited_once_with()
+
+ def test_cog_unload_cancelled_tasks(self):
+ """The init task was cancelled."""
+ self.cog._init_task = asyncio.Future()
+ self.cog.cog_unload()
+
+ # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda.
+ self.assertTrue(self.cog._init_task.cancelled())
+
+ @autospec("discord.ext.commands", "has_any_role")
+ @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3))
+ async def test_cog_check(self, role_check):
+ """Role check was called with `MODERATION_ROLES`"""
+ ctx = MockContext()
+ role_check.return_value.predicate = mock.AsyncMock()
+
+ await self.cog.cog_check(ctx)
+ role_check.assert_called_once_with(*(1, 2, 3))
+ role_check.return_value.predicate.assert_awaited_once_with(ctx)
+
+
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class RescheduleTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the rescheduling of cached unsilences."""
+
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper)
+
+ with mock.patch.object(self.cog, "_reschedule", autospec=True):
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ async def test_skipped_missing_channel(self):
+ """Did nothing because the channel couldn't be retrieved."""
+ self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)]
+ self.bot.get_channel.return_value = None
+
+ await self.cog._reschedule()
+
+ self.cog.notifier.add_channel.assert_not_called()
+ self.cog._unsilence_wrapper.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ async def test_added_permanent_to_notifier(self):
+ """Permanently silenced channels were added to the notifier."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)]
+
+ await self.cog._reschedule()
+
+ self.cog.notifier.add_channel.assert_any_call(channels[0])
+ self.cog.notifier.add_channel.assert_any_call(channels[1])
+
+ self.cog._unsilence_wrapper.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ async def test_unsilenced_expired(self):
+ """Unsilenced expired silences."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)]
+
+ await self.cog._reschedule()
+
+ self.cog._unsilence_wrapper.assert_any_call(channels[0])
+ self.cog._unsilence_wrapper.assert_any_call(channels[1])
+
+ self.cog.notifier.add_channel.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ @mock.patch.object(silence, "datetime", new=PatchedDatetime)
+ async def test_rescheduled_active(self):
+ """Rescheduled active silences."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)]
+ silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc)
+
+ self.cog._unsilence_wrapper = mock.MagicMock()
+ unsilence_return = self.cog._unsilence_wrapper.return_value
+
+ await self.cog._reschedule()
+
+ # Yuck.
+ calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)]
+ self.cog.scheduler.schedule_later.assert_has_calls(calls)
+
+ unsilence_calls = [mock.call(channel) for channel in channels]
+ self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls)
+
+ self.cog.notifier.add_channel.assert_not_called()
+
+
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class SilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the silence command and its related helper methods."""
+
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ # Avoid unawaited coroutine warnings.
+ self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close()
+
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ self.channel = MockTextChannel()
+ self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False)
+ self.channel.overwrites_for.return_value = self.overwrite
+
+ async def test_sent_correct_message(self):
+ """Appropriate failure/success message was sent by the command."""
test_cases = (
- (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,),
- (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,),
- (5, f"{Emojis.cross_mark} current channel is already silenced.", False,),
+ (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,),
+ (None, silence.MSG_SILENCE_PERMANENT, True,),
+ (5, silence.MSG_SILENCE_FAIL, False,),
)
- for duration, result_message, _silence_patch_return in test_cases:
- with self.subTest(
- silence_duration=duration,
- result_message=result_message,
- starting_unsilenced_state=_silence_patch_return
- ):
- with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return):
- await self.cog.silence(self.cog, self.ctx, duration)
- self.ctx.send.assert_called_once_with(result_message)
- self.ctx.reset_mock()
-
- async def test_unsilence_sent_correct_discord_message(self):
- """Check if proper message was sent when unsilencing channel."""
- test_cases = (
- (True, f"{Emojis.check_mark} unsilenced current channel."),
- (False, f"{Emojis.cross_mark} current channel was not silenced.")
+ for duration, message, was_silenced in test_cases:
+ ctx = MockContext()
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced):
+ with self.subTest(was_silenced=was_silenced, message=message, duration=duration):
+ await self.cog.silence.callback(self.cog, ctx, duration)
+ ctx.send.assert_called_once_with(message)
+
+ async def test_skipped_already_silenced(self):
+ """Permissions were not set and `False` was returned for an already silenced channel."""
+ subtests = (
+ (False, PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (True, PermissionOverwrite(send_messages=True, add_reactions=True)),
+ (True, PermissionOverwrite(send_messages=False, add_reactions=False)),
)
- for _unsilence_patch_return, result_message in test_cases:
- with self.subTest(
- starting_silenced_state=_unsilence_patch_return,
- result_message=result_message
- ):
- with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return):
- await self.cog.unsilence(self.cog, self.ctx)
- self.ctx.send.assert_called_once_with(result_message)
- self.ctx.reset_mock()
-
- async def test_silence_private_for_false(self):
- """Permissions are not set and `False` is returned in an already silenced channel."""
- perm_overwrite = Mock(send_messages=False)
- channel = Mock(overwrites_for=Mock(return_value=perm_overwrite))
-
- self.assertFalse(await self.cog._silence(channel, True, None))
- channel.set_permissions.assert_not_called()
- async def test_silence_private_silenced_channel(self):
- """Channel had `send_message` permissions revoked."""
- channel = MockTextChannel()
- self.assertTrue(await self.cog._silence(channel, False, None))
- channel.set_permissions.assert_called_once()
- self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages'])
+ for contains, overwrite in subtests:
+ with self.subTest(contains=contains, overwrite=overwrite):
+ self.cog.scheduler.__contains__.return_value = contains
+ channel = MockTextChannel()
+ channel.overwrites_for.return_value = overwrite
+
+ self.assertFalse(await self.cog._set_silence_overwrites(channel))
+ channel.set_permissions.assert_not_called()
+
+ async def test_silenced_channel(self):
+ """Channel had `send_message` and `add_reactions` permissions revoked for verified role."""
+ self.assertTrue(await self.cog._set_silence_overwrites(self.channel))
+ self.assertFalse(self.overwrite.send_messages)
+ self.assertFalse(self.overwrite.add_reactions)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite
+ )
- async def test_silence_private_preserves_permissions(self):
- """Previous permissions were preserved when channel was silenced."""
- channel = MockTextChannel()
- # Set up mock channel permission state.
- mock_permissions = PermissionOverwrite()
- mock_permissions_dict = dict(mock_permissions)
- channel.overwrites_for.return_value = mock_permissions
- await self.cog._silence(channel, False, None)
- new_permissions = channel.set_permissions.call_args.kwargs
- # Remove 'send_messages' key because it got changed in the method.
- del new_permissions['send_messages']
- del mock_permissions_dict['send_messages']
- self.assertDictEqual(mock_permissions_dict, new_permissions)
-
- async def test_silence_private_notifier(self):
- """Channel should be added to notifier with `persistent` set to `True`, and the other way around."""
- channel = MockTextChannel()
- with mock.patch.object(self.cog, "notifier", create=True):
- with self.subTest(persistent=True):
- await self.cog._silence(channel, True, None)
- self.cog.notifier.add_channel.assert_called_once()
-
- with mock.patch.object(self.cog, "notifier", create=True):
- with self.subTest(persistent=False):
- await self.cog._silence(channel, False, None)
- self.cog.notifier.add_channel.assert_not_called()
-
- async def test_silence_private_added_muted_channel(self):
- """Channel was added to `muted_channels` on silence."""
+ async def test_preserved_other_overwrites(self):
+ """Channel's other unrelated overwrites were not changed."""
+ prev_overwrite_dict = dict(self.overwrite)
+ await self.cog._set_silence_overwrites(self.channel)
+ new_overwrite_dict = dict(self.overwrite)
+
+ # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.
+ del prev_overwrite_dict['send_messages']
+ del prev_overwrite_dict['add_reactions']
+ del new_overwrite_dict['send_messages']
+ del new_overwrite_dict['add_reactions']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+
+ async def test_temp_not_added_to_notifier(self):
+ """Channel was not added to notifier if a duration was set for the silence."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ await self.cog.silence.callback(self.cog, MockContext(), 15)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_indefinite_added_to_notifier(self):
+ """Channel was added to notifier if a duration was not set for the silence."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ await self.cog.silence.callback(self.cog, MockContext(), None)
+ self.cog.notifier.add_channel.assert_called_once()
+
+ async def test_silenced_not_added_to_notifier(self):
+ """Channel was not added to the notifier if it was already silenced."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False):
+ await self.cog.silence.callback(self.cog, MockContext(), 15)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_cached_previous_overwrites(self):
+ """Channel's previous overwrites were cached."""
+ overwrite_json = '{"send_messages": true, "add_reactions": false}'
+ await self.cog._set_silence_overwrites(self.channel)
+ self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json)
+
+ @autospec(silence, "datetime")
+ async def test_cached_unsilence_time(self, datetime_mock):
+ """The UTC POSIX timestamp for the unsilence was cached."""
+ now_timestamp = 100
+ duration = 15
+ timestamp = now_timestamp + duration * 60
+ datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc)
+
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, duration)
+
+ self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp)
+ datetime_mock.now.assert_called_once_with(tz=timezone.utc) # Ensure it's using an aware dt.
+
+ async def test_cached_indefinite_time(self):
+ """A value of -1 was cached for a permanent silence."""
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, None)
+ self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)
+
+ async def test_scheduled_task(self):
+ """An unsilence task was scheduled."""
+ ctx = MockContext(channel=self.channel, invoke=mock.MagicMock())
+
+ await self.cog.silence.callback(self.cog, ctx, 5)
+
+ args = (300, ctx.channel.id, ctx.invoke.return_value)
+ self.cog.scheduler.schedule_later.assert_called_once_with(*args)
+ ctx.invoke.assert_called_once_with(self.cog.unsilence)
+
+ async def test_permanent_not_scheduled(self):
+ """A task was not scheduled for a permanent silence."""
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, None)
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+
+@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)
+class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the unsilence command and its related helper methods."""
+
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot(get_channel=lambda _: MockTextChannel())
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
+ self.cog.previous_overwrites = overwrites_cache
+
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ self.cog.scheduler.__contains__.return_value = True
+ overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
+ self.channel = MockTextChannel()
+ self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False)
+ self.channel.overwrites_for.return_value = self.overwrite
+
+ async def test_sent_correct_message(self):
+ """Appropriate failure/success message was sent by the command."""
+ unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True)
+ test_cases = (
+ (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite),
+ (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)),
+ (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)),
+ )
+ for was_unsilenced, message, overwrite in test_cases:
+ ctx = MockContext()
+ with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite):
+ with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced):
+ ctx.channel.overwrites_for.return_value = overwrite
+ await self.cog.unsilence.callback(self.cog, ctx)
+ ctx.channel.send.assert_called_once_with(message)
+
+ async def test_skipped_already_unsilenced(self):
+ """Permissions were not set and `False` was returned for an already unsilenced channel."""
+ self.cog.scheduler.__contains__.return_value = False
+ self.cog.previous_overwrites.get.return_value = None
channel = MockTextChannel()
- with mock.patch.object(self.cog, "muted_channels") as muted_channels:
- await self.cog._silence(channel, False, None)
- muted_channels.add.assert_called_once_with(channel)
- async def test_unsilence_private_for_false(self):
- """Permissions are not set and `False` is returned in an unsilenced channel."""
- channel = Mock()
self.assertFalse(await self.cog._unsilence(channel))
channel.set_permissions.assert_not_called()
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_unsilenced_channel(self, _):
- """Channel had `send_message` permissions restored"""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- self.assertTrue(await self.cog._unsilence(channel))
- channel.set_permissions.assert_called_once()
- self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages'])
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_removed_notifier(self, notifier):
- """Channel was removed from `notifier` on unsilence."""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- await self.cog._unsilence(channel)
- notifier.remove_channel.assert_called_once_with(channel)
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_removed_muted_channel(self, _):
- """Channel was removed from `muted_channels` on unsilence."""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- with mock.patch.object(self.cog, "muted_channels") as muted_channels:
- await self.cog._unsilence(channel)
- muted_channels.discard.assert_called_once_with(channel)
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_preserves_permissions(self, _):
- """Previous permissions were preserved when channel was unsilenced."""
- channel = MockTextChannel()
- # Set up mock channel permission state.
- mock_permissions = PermissionOverwrite(send_messages=False)
- mock_permissions_dict = dict(mock_permissions)
- channel.overwrites_for.return_value = mock_permissions
- await self.cog._unsilence(channel)
- new_permissions = channel.set_permissions.call_args.kwargs
- # Remove 'send_messages' key because it got changed in the method.
- del new_permissions['send_messages']
- del mock_permissions_dict['send_messages']
- self.assertDictEqual(mock_permissions_dict, new_permissions)
-
- @mock.patch("bot.exts.moderation.silence.asyncio")
- @mock.patch.object(Silence, "_mod_alerts_channel", create=True)
- def test_cog_unload_starts_task(self, alert_channel, asyncio_mock):
- """Task for sending an alert was created with present `muted_channels`."""
- with mock.patch.object(self.cog, "muted_channels"):
- self.cog.cog_unload()
- alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ")
- asyncio_mock.create_task.assert_called_once_with(alert_channel.send())
-
- @mock.patch("bot.exts.moderation.silence.asyncio")
- def test_cog_unload_skips_task_start(self, asyncio_mock):
- """No task created with no channels."""
- self.cog.cog_unload()
- asyncio_mock.create_task.assert_not_called()
+ async def test_restored_overwrites(self):
+ """Channel's `send_message` and `add_reactions` overwrites were restored."""
+ await self.cog._unsilence(self.channel)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite,
+ )
- @mock.patch("discord.ext.commands.has_any_role")
- @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))
- async def test_cog_check(self, role_check):
- """Role check is called with `MODERATION_ROLES`"""
- role_check.return_value.predicate = mock.AsyncMock()
- await self.cog.cog_check(self.ctx)
- role_check.assert_called_once_with(*(1, 2, 3))
- role_check.return_value.predicate.assert_awaited_once_with(self.ctx)
+ # Recall that these values are determined by the fixture.
+ self.assertTrue(self.overwrite.send_messages)
+ self.assertFalse(self.overwrite.add_reactions)
+
+ async def test_cache_miss_used_default_overwrites(self):
+ """Both overwrites were set to None due previous values not being found in the cache."""
+ self.cog.previous_overwrites.get.return_value = None
+
+ await self.cog._unsilence(self.channel)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite,
+ )
+
+ self.assertIsNone(self.overwrite.send_messages)
+ self.assertIsNone(self.overwrite.add_reactions)
+
+ async def test_cache_miss_sent_mod_alert(self):
+ """A message was sent to the mod alerts channel."""
+ self.cog.previous_overwrites.get.return_value = None
+
+ await self.cog._unsilence(self.channel)
+ self.cog._mod_alerts_channel.send.assert_awaited_once()
+
+ async def test_removed_notifier(self):
+ """Channel was removed from `notifier`."""
+ await self.cog._unsilence(self.channel)
+ self.cog.notifier.remove_channel.assert_called_once_with(self.channel)
+
+ async def test_deleted_cached_overwrite(self):
+ """Channel was deleted from the overwrites cache."""
+ await self.cog._unsilence(self.channel)
+ self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id)
+
+ async def test_deleted_cached_time(self):
+ """Channel was deleted from the timestamp cache."""
+ await self.cog._unsilence(self.channel)
+ self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id)
+
+ async def test_cancelled_task(self):
+ """The scheduled unsilence task should be cancelled."""
+ await self.cog._unsilence(self.channel)
+ self.cog.scheduler.cancel.assert_called_once_with(self.channel.id)
+
+ async def test_preserved_other_overwrites(self):
+ """Channel's other unrelated overwrites were not changed, including cache misses."""
+ for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):
+ with self.subTest(overwrite_json=overwrite_json):
+ self.cog.previous_overwrites.get.return_value = overwrite_json
+
+ prev_overwrite_dict = dict(self.overwrite)
+ await self.cog._unsilence(self.channel)
+ new_overwrite_dict = dict(self.overwrite)
+
+ # Remove these keys because they were modified by the unsilence.
+ del prev_overwrite_dict['send_messages']
+ del prev_overwrite_dict['add_reactions']
+ del new_overwrite_dict['send_messages']
+ del new_overwrite_dict['add_reactions']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 6601fad2c..9a42d0610 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -52,6 +52,13 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),
('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'),
+ ('text```print("Hello world!")```text', 'print("Hello world!")', 'code block surrounded by text'),
+ ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```',
+ 'print("Hello world!")\nprint("Hello world!")', 'two code blocks with text in-between'),
+ ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```',
+ 'print("How\'s it going?")', 'code block preceded by inline code'),
+ ('`print("Hello world!")`\ntext\n`print("Hello world!")`',
+ 'print("Hello world!")', 'one inline code block of two')
)
for case, expected, testname in cases:
with self.subTest(msg=f'Extract code from {testname}.'):
diff --git a/tests/helpers.py b/tests/helpers.py
index e47fdf28f..870f66197 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -5,7 +5,7 @@ import itertools
import logging
import unittest.mock
from asyncio import AbstractEventLoop
-from typing import Callable, Iterable, Optional
+from typing import Iterable, Optional
import discord
from aiohttp import ClientSession
@@ -14,6 +14,7 @@ from discord.ext.commands import Context
from bot.api import APIClient
from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
+from tests._autospec import autospec # noqa: F401 other modules import it via this module
for logger in logging.Logger.manager.loggerDict.values():
@@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def autospec(target, *attributes: str, **kwargs) -> Callable:
- """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True."""
- # Caller's kwargs should take priority and overwrite the defaults.
- kwargs = {'spec_set': True, 'autospec': True, **kwargs}
-
- # Import the target if it's a string.
- # This is to support both object and string targets like patch.multiple.
- if type(target) is str:
- target = unittest.mock._importer(target)
-
- def decorator(func):
- for attribute in attributes:
- patcher = unittest.mock.patch.object(target, attribute, **kwargs)
- func = patcher(func)
- return func
- return decorator
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.