aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--LICENSE-THIRD-PARTY88
-rw-r--r--Pipfile2
-rw-r--r--Pipfile.lock25
-rw-r--r--README.md2
-rw-r--r--bot/__main__.py12
-rw-r--r--bot/constants.py27
-rw-r--r--bot/exts/backend/sync/_syncers.py98
-rw-r--r--bot/exts/filters/antimalware.py4
-rw-r--r--bot/exts/filters/antispam.py3
-rw-r--r--bot/exts/fun/duck_pond.py4
-rw-r--r--bot/exts/help_channels.py45
-rw-r--r--bot/exts/info/codeblock/__init__.py8
-rw-r--r--bot/exts/info/codeblock/_cog.py186
-rw-r--r--bot/exts/info/codeblock/_instructions.py184
-rw-r--r--bot/exts/info/codeblock/_parsing.py228
-rw-r--r--bot/exts/info/information.py92
-rw-r--r--bot/exts/info/reddit.py8
-rw-r--r--bot/exts/info/site.py21
-rw-r--r--bot/exts/info/stats.py42
-rw-r--r--bot/exts/moderation/dm_relay.py6
-rw-r--r--bot/exts/moderation/infraction/_scheduler.py23
-rw-r--r--bot/exts/moderation/infraction/_utils.py5
-rw-r--r--bot/exts/moderation/infraction/infractions.py108
-rw-r--r--bot/exts/moderation/infraction/management.py10
-rw-r--r--bot/exts/moderation/infraction/superstarify.py5
-rw-r--r--bot/exts/moderation/silence.py202
-rw-r--r--bot/exts/moderation/verification.py81
-rw-r--r--bot/exts/moderation/voice_gate.py168
-rw-r--r--bot/exts/utils/bot.py326
-rw-r--r--bot/exts/utils/ping.py2
-rw-r--r--bot/exts/utils/reminders.py8
-rw-r--r--bot/exts/utils/snekbox.py6
-rw-r--r--bot/utils/__init__.py4
-rw-r--r--bot/utils/channel.py49
-rw-r--r--bot/utils/helpers.py9
-rw-r--r--bot/utils/messages.py31
-rw-r--r--config-default.yml51
-rw-r--r--docker-compose.yml1
-rw-r--r--tests/_autospec.py64
-rw-r--r--tests/bot/exts/backend/sync/test_users.py120
-rw-r--r--tests/bot/exts/info/test_information.py73
-rw-r--r--tests/bot/exts/moderation/infraction/test_infractions.py148
-rw-r--r--tests/bot/exts/moderation/test_silence.py587
-rw-r--r--tests/helpers.py21
45 files changed, 2242 insertions, 946 deletions
diff --git a/.gitignore b/.gitignore
index fb3156ab1..2074887ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,6 +110,7 @@ ENV/
# Logfiles
log.*
+*.log.*
# Custom user configuration
config.yml
diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY
new file mode 100644
index 000000000..eacd9b952
--- /dev/null
+++ b/LICENSE-THIRD-PARTY
@@ -0,0 +1,88 @@
+---------------------------------------------------------------------------------------------------
+ BSD 3-Clause License
+Applies to:
+ - Copyright (c) 2008-Present, IPython Development Team
+ Copyright (c) 2001-2007, Fernando Perez <[email protected]>
+ Copyright (c) 2001, Janko Hauser <[email protected]>
+ Copyright (c) 2001, Nathaniel Gray <[email protected]>
+ All rights reserved.
+ - bot/exts/info/codeblock/_parsing.py: _RE_PYTHON_REPL and portions of _RE_IPYTHON_REPL
+---------------------------------------------------------------------------------------------------
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+---------------------------------------------------------------------------------------------------
+ PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
+Applies to:
+ - Copyright © 2001-2020 Python Software Foundation. All rights reserved.
+ - tests/_autospec.py: _decoration_helper
+---------------------------------------------------------------------------------------------------
+
+1. This LICENSE AGREEMENT is between the Python Software Foundation
+("PSF"), and the Individual or Organization ("Licensee") accessing and
+otherwise using this software ("Python") in source or binary form and
+its associated documentation.
+
+2. Subject to the terms and conditions of this License Agreement, PSF hereby
+grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
+analyze, test, perform and/or display publicly, prepare derivative works,
+distribute, and otherwise use Python alone or in any derivative version,
+provided, however, that PSF's License Agreement and PSF's notice of copyright,
+i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation;
+All Rights Reserved" are retained in Python alone or in any derivative version
+prepared by Licensee.
+
+3. In the event Licensee prepares a derivative work that is based on
+or incorporates Python or any part thereof, and wants to make
+the derivative work available to others as provided herein, then
+Licensee hereby agrees to include in any such work a brief summary of
+the changes made to Python.
+
+4. PSF is making Python available to Licensee on an "AS IS"
+basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+INFRINGE ANY THIRD PARTY RIGHTS.
+
+5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+6. This License Agreement will automatically terminate upon a material
+breach of its terms and conditions.
+
+7. Nothing in this License Agreement shall be deemed to create any
+relationship of agency, partnership, or joint venture between PSF and
+Licensee. This License Agreement does not grant permission to use PSF
+trademarks or trade name in a trademark sense to endorse or promote
+products or services of Licensee, or any third party.
+
+8. By copying, installing or otherwise using Python, Licensee
+agrees to be bound by the terms and conditions of this License
+Agreement.
diff --git a/Pipfile b/Pipfile
index e6f84d911..99fc70b46 100644
--- a/Pipfile
+++ b/Pipfile
@@ -14,7 +14,7 @@ beautifulsoup4 = "~=4.9"
colorama = {version = "~=0.4.3",sys_platform = "== 'win32'"}
coloredlogs = "~=14.0"
deepdiff = "~=4.0"
-discord.py = "~=1.4.0"
+"discord.py" = "~=1.5.0"
feedparser = "~=5.2"
fuzzywuzzy = "~=0.17"
lxml = "~=4.4"
diff --git a/Pipfile.lock b/Pipfile.lock
index 4c63277de..becd85c55 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
- "sha256": "644012a1c3fa3e3a30f8b8f8e672c468dfaa155d9e43d26e2be8713c8dc5ebb3"
+ "sha256": "073fd0c51749aafa188fdbe96c5b90dd157cb1d23bdd144801fb0d0a369ffa88"
},
"pipfile-spec": 6,
"requires": {
@@ -18,11 +18,11 @@
"default": {
"aio-pika": {
"hashes": [
- "sha256:4a20d4d941e1f113a950ea529a90bd9159c8d7aafaa1c71e9c707c8c2b526ea6",
- "sha256:7bf3f183df1eb348d007210a0c1a3c5c755f1b3def1a9a395e93f30b91da1daf"
+ "sha256:9773440a89840941ac3099a7720bf9d51e8764a484066b82ede4d395660ff430",
+ "sha256:a8065be3c722eb8f9fff8c0e7590729e7782202cdb9363d9830d7d5d47b45c7c"
],
"index": "pypi",
- "version": "==6.7.0"
+ "version": "==6.7.1"
},
"aiodns": {
"hashes": [
@@ -205,22 +205,13 @@
"index": "pypi",
"version": "==4.3.2"
},
- "discord": {
- "hashes": [
- "sha256:9d4debb4a37845543bd4b92cb195bc53a302797333e768e70344222857ff1559",
- "sha256:ff6653655e342e7721dfb3f10421345fd852c2a33f2cca912b1c39b3778a9429"
- ],
- "index": "pypi",
- "py": "~=1.4.0",
- "version": "==1.0.1"
- },
"discord.py": {
"hashes": [
- "sha256:98ea3096a3585c9c379209926f530808f5fcf4930928d8cfb579d2562d119570",
- "sha256:f9decb3bfa94613d922376288617e6a6f969260923643e2897f4540c34793442"
+ "sha256:3acb61fde0d862ed346a191d69c46021e6063673f63963bc984ae09a685ab211",
+ "sha256:e71089886aa157341644bdecad63a72ff56b44406b1a6467b66db31c8e5a5a15"
],
- "markers": "python_full_version >= '3.5.3'",
- "version": "==1.4.1"
+ "index": "pypi",
+ "version": "==1.5.0"
},
"docutils": {
"hashes": [
diff --git a/README.md b/README.md
index cae7c3454..b37ece296 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Python Utility Bot
-[![Discord](https://img.shields.io/static/v1?label=Python%20Discord&logo=discord&message=%3E60k%20members&color=%237289DA&logoColor=white)](https://discord.gg/2B963hn)
+[![Discord](https://img.shields.io/static/v1?label=Python%20Discord&logo=discord&message=%3E100k%20members&color=%237289DA&logoColor=white)](https://discord.gg/2B963hn)
[![Build Status](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)](https://dev.azure.com/python-discord/Python%20Discord/_build/latest?definitionId=1&branchName=master)
[![Tests](https://img.shields.io/azure-devops/tests/python-discord/Python%20Discord/1?compact_message)](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)
[![Coverage](https://img.shields.io/azure-devops/coverage/python-discord/Python%20Discord/1/master)](https://dev.azure.com/python-discord/Python%20Discord/_apis/build/status/Bot?branchName=master)
diff --git a/bot/__main__.py b/bot/__main__.py
index 152ddbf92..367be1300 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -47,14 +47,22 @@ loop.run_until_complete(redis_session.connect())
# Instantiate the bot.
allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES]
+intents = discord.Intents().all()
+intents.presences = False
+intents.dm_typing = False
+intents.dm_reactions = False
+intents.invites = False
+intents.webhooks = False
+intents.integrations = False
bot = Bot(
redis_session=redis_session,
loop=loop,
command_prefix=when_mentioned_or(constants.Bot.prefix),
- activity=discord.Game(name="Commands: !help"),
+ activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),
case_insensitive=True,
max_messages=10_000,
- allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles)
+ allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles),
+ intents=intents,
)
# Load extensions.
diff --git a/bot/constants.py b/bot/constants.py
index c21fd52e0..23d5b4304 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -377,6 +377,7 @@ class Categories(metaclass=YAMLGetter):
help_in_use: int
help_dormant: int
modmail: int
+ voice: int
class Channels(metaclass=YAMLGetter):
@@ -392,6 +393,7 @@ class Channels(metaclass=YAMLGetter):
bot_commands: int
change_log: int
code_help_voice: int
+ code_help_voice_2: int
cooldown: int
defcon: int
dev_contrib: int
@@ -424,6 +426,8 @@ class Channels(metaclass=YAMLGetter):
user_event_announcements: int
user_log: int
verification: int
+ voice_chat: int
+ voice_gate: int
voice_log: int
@@ -456,9 +460,11 @@ class Roles(metaclass=YAMLGetter):
owners: int
partners: int
python_community: int
+ sprinters: int
team_leaders: int
unverified: int
verified: int # This is the Developers role on PyDis, here named verified for readability reasons.
+ voice_verified: int
class Guild(metaclass=YAMLGetter):
@@ -467,6 +473,7 @@ class Guild(metaclass=YAMLGetter):
id: int
invite: str # Discord invite, gets embedded in chat
moderation_channels: List[int]
+ moderation_categories: List[int]
moderation_roles: List[int]
modlog_blacklist: List[int]
reminder_whitelist: List[int]
@@ -528,6 +535,15 @@ class BigBrother(metaclass=YAMLGetter):
header_message_limit: int
+class CodeBlock(metaclass=YAMLGetter):
+ section = 'code_block'
+
+ channel_whitelist: List[int]
+ cooldown_channels: List[int]
+ cooldown_seconds: int
+ minimum_lines: int
+
+
class Free(metaclass=YAMLGetter):
section = 'free'
@@ -578,6 +594,14 @@ class Verification(metaclass=YAMLGetter):
kick_confirmation_threshold: float
+class VoiceGate(metaclass=YAMLGetter):
+ section = "voice_gate"
+
+ minimum_days_verified: int
+ minimum_messages: int
+ bot_message_delete_delay: int
+
+
class Event(Enum):
"""
Event names. This does not include every event (for example, raw
@@ -618,6 +642,9 @@ STAFF_ROLES = Guild.staff_roles
# Channel combinations
MODERATION_CHANNELS = Guild.moderation_channels
+# Category combinations
+MODERATION_CATEGORIES = Guild.moderation_categories
+
# Bot replies
NEGATIVE_REPLIES = [
"Noooooo!!",
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py
index 3d4a09df3..38468c2b1 100644
--- a/bot/exts/backend/sync/_syncers.py
+++ b/bot/exts/backend/sync/_syncers.py
@@ -14,7 +14,6 @@ log = logging.getLogger(__name__)
# These objects are declared as namedtuples because tuples are hashable,
# something that we make use of when diffing site roles against guild roles.
_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
-_User = namedtuple('User', ('id', 'name', 'discriminator', 'roles', 'in_guild'))
_Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
@@ -134,61 +133,76 @@ class UserSyncer(Syncer):
async def _get_diff(self, guild: Guild) -> _Diff:
"""Return the difference of users between the cache of `guild` and the database."""
log.trace("Getting the diff for users.")
- users = await self.bot.api_client.get('bot/users')
- # Pack DB roles and guild roles into one common, hashable format.
- # They're hashable so that they're easily comparable with sets later.
- db_users = {
- user_dict['id']: _User(
- roles=tuple(sorted(user_dict.pop('roles'))),
- **user_dict
- )
- for user_dict in users
- }
- guild_users = {
- member.id: _User(
- id=member.id,
- name=member.name,
- discriminator=int(member.discriminator),
- roles=tuple(sorted(role.id for role in member.roles)),
- in_guild=True
- )
- for member in guild.members
- }
+ users_to_create = []
+ users_to_update = []
+ seen_guild_users = set()
+
+ async for db_user in self._get_users():
+ # Store user fields which are to be updated.
+ updated_fields = {}
- users_to_create = set()
- users_to_update = set()
+ def maybe_update(db_field: str, guild_value: t.Union[str, int]) -> None:
+ # Equalize DB user and guild user attributes.
+ if db_user[db_field] != guild_value:
+ updated_fields[db_field] = guild_value
- for db_user in db_users.values():
- guild_user = guild_users.get(db_user.id)
- if guild_user is not None:
- if db_user != guild_user:
- users_to_update.add(guild_user)
+ if guild_user := guild.get_member(db_user["id"]):
+ seen_guild_users.add(guild_user.id)
- elif db_user.in_guild:
+ maybe_update("name", guild_user.name)
+ maybe_update("discriminator", int(guild_user.discriminator))
+ maybe_update("in_guild", True)
+
+ guild_roles = [role.id for role in guild_user.roles]
+ if set(db_user["roles"]) != set(guild_roles):
+ updated_fields["roles"] = guild_roles
+
+ elif db_user["in_guild"]:
# The user is known in the DB but not the guild, and the
# DB currently specifies that the user is a member of the guild.
# This means that the user has left since the last sync.
# Update the `in_guild` attribute of the user on the site
# to signify that the user left.
- new_api_user = db_user._replace(in_guild=False)
- users_to_update.add(new_api_user)
-
- new_user_ids = set(guild_users.keys()) - set(db_users.keys())
- for user_id in new_user_ids:
- # The user is known on the guild but not on the API. This means
- # that the user has joined since the last sync. Create it.
- new_user = guild_users[user_id]
- users_to_create.add(new_user)
+ updated_fields["in_guild"] = False
+
+ if updated_fields:
+ updated_fields["id"] = db_user["id"]
+ users_to_update.append(updated_fields)
+
+ for member in guild.members:
+ if member.id not in seen_guild_users:
+ # The user is known on the guild but not on the API. This means
+ # that the user has joined since the last sync. Create it.
+ new_user = {
+ "id": member.id,
+ "name": member.name,
+ "discriminator": int(member.discriminator),
+ "roles": [role.id for role in member.roles],
+ "in_guild": True
+ }
+ users_to_create.append(new_user)
return _Diff(users_to_create, users_to_update, None)
+ async def _get_users(self) -> t.AsyncIterable:
+ """GET users from database."""
+ query_params = {
+ "page": 1
+ }
+ while query_params["page"]:
+ res = await self.bot.api_client.get("bot/users", params=query_params)
+ for user in res["results"]:
+ yield user
+
+ query_params["page"] = res["next_page_no"]
+
async def _sync(self, diff: _Diff) -> None:
"""Synchronise the database with the user cache of `guild`."""
log.trace("Syncing created users...")
- for user in diff.created:
- await self.bot.api_client.post('bot/users', json=user._asdict())
+ if diff.created:
+ await self.bot.api_client.post("bot/users", json=diff.created)
log.trace("Syncing updated users...")
- for user in diff.updated:
- await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict())
+ if diff.updated:
+ await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated)
diff --git a/bot/exts/filters/antimalware.py b/bot/exts/filters/antimalware.py
index 7894ec48f..26f00e91f 100644
--- a/bot/exts/filters/antimalware.py
+++ b/bot/exts/filters/antimalware.py
@@ -6,7 +6,7 @@ from discord import Embed, Message, NotFound
from discord.ext.commands import Cog
from bot.bot import Bot
-from bot.constants import Channels, STAFF_ROLES, URLs
+from bot.constants import Channels, Filter, URLs
log = logging.getLogger(__name__)
@@ -61,7 +61,7 @@ class AntiMalware(Cog):
# Check if user is staff, if is, return
# Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
- if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles):
+ if hasattr(message.author, "roles") and any(role.id in Filter.role_whitelist for role in message.author.roles):
return
embed = Embed()
diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py
index 4964283f1..af8528a68 100644
--- a/bot/exts/filters/antispam.py
+++ b/bot/exts/filters/antispam.py
@@ -15,7 +15,6 @@ from bot.constants import (
AntiSpam as AntiSpamConfig, Channels,
Colours, DEBUG_MODE, Event, Filter,
Guild as GuildConfig, Icons,
- STAFF_ROLES,
)
from bot.converters import Duration
from bot.exts.moderation.modlog import ModLog
@@ -149,7 +148,7 @@ class AntiSpam(Cog):
or message.guild.id != GuildConfig.id
or message.author.bot
or (message.channel.id in Filter.channel_whitelist and not DEBUG_MODE)
- or (any(role.id in STAFF_ROLES for role in message.author.roles) and not DEBUG_MODE)
+ or (any(role.id in Filter.role_whitelist for role in message.author.roles) and not DEBUG_MODE)
):
return
diff --git a/bot/exts/fun/duck_pond.py b/bot/exts/fun/duck_pond.py
index 82084ea88..48aa2749c 100644
--- a/bot/exts/fun/duck_pond.py
+++ b/bot/exts/fun/duck_pond.py
@@ -22,6 +22,7 @@ class DuckPond(Cog):
self.bot = bot
self.webhook_id = constants.Webhooks.duck_pond
self.webhook = None
+ self.ducked_messages = []
self.bot.loop.create_task(self.fetch_webhook())
self.relay_lock = None
@@ -176,7 +177,8 @@ class DuckPond(Cog):
duck_count = await self.count_ducks(message)
# If we've got more than the required amount of ducks, send the message to the duck_pond.
- if duck_count >= constants.DuckPond.threshold:
+ if duck_count >= constants.DuckPond.threshold and message.id not in self.ducked_messages:
+ self.ducked_messages.append(message.id)
await self.locked_relay(message)
@Cog.listener()
diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py
index f5c9a5dd0..062d4fcfe 100644
--- a/bot/exts/help_channels.py
+++ b/bot/exts/help_channels.py
@@ -14,6 +14,7 @@ from discord.ext import commands
from bot import constants
from bot.bot import Bot
+from bot.utils import channel as channel_utils
from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
@@ -378,11 +379,18 @@ class HelpChannels(commands.Cog):
log.trace("Getting the CategoryChannel objects for the help categories.")
try:
- self.available_category = await self.try_get_channel(
- constants.Categories.help_available
+ self.available_category = await channel_utils.try_get_channel(
+ constants.Categories.help_available,
+ self.bot
+ )
+ self.in_use_category = await channel_utils.try_get_channel(
+ constants.Categories.help_in_use,
+ self.bot
+ )
+ self.dormant_category = await channel_utils.try_get_channel(
+ constants.Categories.help_dormant,
+ self.bot
)
- self.in_use_category = await self.try_get_channel(constants.Categories.help_in_use)
- self.dormant_category = await self.try_get_channel(constants.Categories.help_dormant)
except discord.HTTPException:
log.exception("Failed to get a category; cog will be removed")
self.bot.remove_cog(self.qualified_name)
@@ -442,12 +450,6 @@ class HelpChannels(commands.Cog):
return False
return message.author == self.bot.user and bot_msg_desc.strip() == description.strip()
- @staticmethod
- def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
- """Return True if `channel` is within a category with `category_id`."""
- actual_category = getattr(channel, "category", None)
- return actual_category is not None and actual_category.id == category_id
-
async def move_idle_channel(self, channel: discord.TextChannel, has_task: bool = True) -> None:
"""
Make the `channel` dormant if idle or schedule the move if still active.
@@ -498,7 +500,7 @@ class HelpChannels(commands.Cog):
options should be avoided, as it may interfere with the category move we perform.
"""
# Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had.
- category = await self.try_get_channel(category_id)
+ category = await channel_utils.try_get_channel(category_id, self.bot)
payload = [{"id": c.id, "position": c.position} for c in category.channels]
@@ -646,7 +648,7 @@ class HelpChannels(commands.Cog):
channel = message.channel
# Confirm the channel is an in use help channel
- if self.is_in_category(channel, constants.Categories.help_in_use):
+ if channel_utils.is_in_category(channel, constants.Categories.help_in_use):
log.trace(f"Checking if #{channel} ({channel.id}) has been answered.")
# Check if there is an entry in unanswered
@@ -671,7 +673,8 @@ class HelpChannels(commands.Cog):
await self.check_for_answer(message)
- if not self.is_in_category(channel, constants.Categories.help_available) or self.is_excluded_channel(channel):
+ is_available = channel_utils.is_in_category(channel, constants.Categories.help_available)
+ if not is_available or self.is_excluded_channel(channel):
return # Ignore messages outside the Available category or in excluded channels.
log.trace("Waiting for the cog to be ready before processing messages.")
@@ -681,7 +684,7 @@ class HelpChannels(commands.Cog):
async with self.on_message_lock:
log.trace(f"on_message lock acquired for {message.id}.")
- if not self.is_in_category(channel, constants.Categories.help_available):
+ if not channel_utils.is_in_category(channel, constants.Categories.help_available):
log.debug(
f"Message {message.id} will not make #{channel} ({channel.id}) in-use "
f"because another message in the channel already triggered that."
@@ -719,7 +722,7 @@ class HelpChannels(commands.Cog):
The new time for the dormant task is configured with `HelpChannels.deleted_idle_minutes`.
"""
- if not self.is_in_category(msg.channel, constants.Categories.help_in_use):
+ if not channel_utils.is_in_category(msg.channel, constants.Categories.help_in_use):
return
if not await self.is_empty(msg.channel):
@@ -844,18 +847,6 @@ class HelpChannels(commands.Cog):
log.trace(f"Dormant message not found in {channel_info}; sending a new message.")
await channel.send(embed=embed)
- async def try_get_channel(self, channel_id: int) -> discord.abc.GuildChannel:
- """Attempt to get or fetch a channel and return it."""
- log.trace(f"Getting the channel {channel_id}.")
-
- channel = self.bot.get_channel(channel_id)
- if not channel:
- log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
- channel = await self.bot.fetch_channel(channel_id)
-
- log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
- return channel
-
async def pin_wrapper(self, msg_id: int, channel: discord.TextChannel, *, pin: bool) -> bool:
"""
Pin message `msg_id` in `channel` if `pin` is True or unpin if it's False.
diff --git a/bot/exts/info/codeblock/__init__.py b/bot/exts/info/codeblock/__init__.py
new file mode 100644
index 000000000..5c55bc5e3
--- /dev/null
+++ b/bot/exts/info/codeblock/__init__.py
@@ -0,0 +1,8 @@
+from bot.bot import Bot
+
+
+def setup(bot: Bot) -> None:
+ """Load the CodeBlockCog cog."""
+ # Defer import to reduce side effects from importing the codeblock package.
+ from bot.exts.info.codeblock._cog import CodeBlockCog
+ bot.add_cog(CodeBlockCog(bot))
diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py
new file mode 100644
index 000000000..1e0feab0d
--- /dev/null
+++ b/bot/exts/info/codeblock/_cog.py
@@ -0,0 +1,186 @@
+import logging
+import time
+from typing import Optional
+
+import discord
+from discord import Message, RawMessageUpdateEvent
+from discord.ext.commands import Cog
+
+from bot import constants
+from bot.bot import Bot
+from bot.exts.filters.token_remover import TokenRemover
+from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE
+from bot.exts.info.codeblock._instructions import get_instructions
+from bot.utils import has_lines
+from bot.utils.channel import is_help_channel
+from bot.utils.messages import wait_for_deletion
+
+log = logging.getLogger(__name__)
+
+
+class CodeBlockCog(Cog, name="Code Block"):
+ """
+ Detect improperly formatted Markdown code blocks and suggest proper formatting.
+
+ There are four basic ways in which a code block is considered improperly formatted:
+
+ 1. The code is not within a code block at all
+ * Ignored if the code is not valid Python or Python REPL code
+ 2. Incorrect characters are used for backticks
+ 3. A language for syntax highlighting is not specified
+ * Ignored if the code is not valid Python or Python REPL code
+ 4. A syntax highlighting language is incorrectly specified
+ * Ignored if the language specified doesn't look like it was meant for Python
+ * This can go wrong in two ways:
+ 1. Spaces before the language
+ 2. No newline immediately following the language
+
+ Messages or code blocks must meet a minimum line count to be detected. Detecting multiple code
+ blocks is supported. However, if at least one code block is correct, then instructions will not
+ be sent even if others are incorrect. When multiple incorrect code blocks are found, only the
+ first one is used as the basis for the instructions sent.
+
+ When an issue is detected, an embed is sent containing specific instructions on fixing what
+ is wrong. If the user edits their message to fix the code block, the instructions will be
+ removed. If they fail to fix the code block with an edit, the instructions will be updated to
+ show what is still incorrect after the user's edit. The embed can be manually deleted with a
+ reaction. Otherwise, it will automatically be removed after 5 minutes.
+
+ The cog only detects messages in whitelisted channels. Channels may also have a cooldown on the
+ instructions being sent. Note all help channels are also whitelisted with cooldowns enabled.
+
+ For configurable parameters, see the `code_block` section in config-default.py.
+ """
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+
+ # Stores allowed channels plus epoch times since the last instructional messages sent.
+ self.channel_cooldowns = dict.fromkeys(constants.CodeBlock.cooldown_channels, 0.0)
+
+ # Maps users' messages to the messages the bot sent with instructions.
+ self.codeblock_message_ids = {}
+
+ @staticmethod
+ def create_embed(instructions: str) -> discord.Embed:
+ """Return an embed which displays code block formatting `instructions`."""
+ return discord.Embed(description=instructions)
+
+ async def get_sent_instructions(self, payload: RawMessageUpdateEvent) -> Optional[Message]:
+ """
+ Return the bot's sent instructions message associated with a user's message `payload`.
+
+ Return None if the message cannot be found. In this case, it's likely the message was
+ deleted either manually via a reaction or automatically by a timer.
+ """
+ log.trace(f"Retrieving instructions message for ID {payload.message_id}")
+ channel = self.bot.get_channel(payload.channel_id)
+
+ try:
+ return await channel.fetch_message(self.codeblock_message_ids[payload.message_id])
+ except discord.NotFound:
+ log.debug("Could not find instructions message; it was probably deleted.")
+ return None
+
+ def is_on_cooldown(self, channel: discord.TextChannel) -> bool:
+ """
+ Return True if an embed was sent too recently for `channel`.
+
+ The cooldown is configured by `constants.CodeBlock.cooldown_seconds`.
+ Note: only channels in the `channel_cooldowns` have cooldowns enabled.
+ """
+ log.trace(f"Checking if #{channel} is on cooldown.")
+ cooldown = constants.CodeBlock.cooldown_seconds
+ return (time.time() - self.channel_cooldowns.get(channel.id, 0)) < cooldown
+
+ def is_valid_channel(self, channel: discord.TextChannel) -> bool:
+ """Return True if `channel` is a help channel, may be on a cooldown, or is whitelisted."""
+ log.trace(f"Checking if #{channel} qualifies for code block detection.")
+ return (
+ is_help_channel(channel)
+ or channel.id in self.channel_cooldowns
+ or channel.id in constants.CodeBlock.channel_whitelist
+ )
+
+ async def send_instructions(self, message: discord.Message, instructions: str) -> None:
+ """
+ Send an embed with `instructions` on fixing an incorrect code block in a `message`.
+
+ The embed will be deleted automatically after 5 minutes.
+ """
+ log.info(f"Sending code block formatting instructions for message {message.id}.")
+
+ embed = self.create_embed(instructions)
+ bot_message = await message.channel.send(f"Hey {message.author.mention}!", embed=embed)
+ self.codeblock_message_ids[message.id] = bot_message.id
+
+ self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,), self.bot))
+
+ # Increase amount of codeblock correction in stats
+ self.bot.stats.incr("codeblock_corrections")
+
+ def should_parse(self, message: discord.Message) -> bool:
+ """
+ Return True if `message` should be parsed.
+
+ A qualifying message:
+
+ 1. Is not authored by a bot
+ 2. Is in a valid channel
+ 3. Has more than 3 lines
+ 4. Has no bot or webhook token
+ """
+ return (
+ not message.author.bot
+ and self.is_valid_channel(message.channel)
+ and has_lines(message.content, constants.CodeBlock.minimum_lines)
+ and not TokenRemover.find_token_in_message(message)
+ and not WEBHOOK_URL_RE.search(message.content)
+ )
+
+ @Cog.listener()
+ async def on_message(self, msg: Message) -> None:
+ """Detect incorrect Markdown code blocks in `msg` and send instructions to fix them."""
+ if not self.should_parse(msg):
+ log.trace(f"Skipping code block detection of {msg.id}: message doesn't qualify.")
+ return
+
+ # When debugging, ignore cooldowns.
+ if self.is_on_cooldown(msg.channel) and not constants.DEBUG_MODE:
+ log.trace(f"Skipping code block detection of {msg.id}: #{msg.channel} is on cooldown.")
+ return
+
+ instructions = get_instructions(msg.content)
+ if instructions:
+ await self.send_instructions(msg, instructions)
+
+ if msg.channel.id not in constants.CodeBlock.channel_whitelist:
+ log.debug(f"Adding #{msg.channel} to the channel cooldowns.")
+ self.channel_cooldowns[msg.channel.id] = time.time()
+
+ @Cog.listener()
+ async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None:
+ """Delete the instructional message if an edited message had its code blocks fixed."""
+ if payload.message_id not in self.codeblock_message_ids:
+ log.trace(f"Ignoring message edit {payload.message_id}: message isn't being tracked.")
+ return
+
+ if payload.data.get("content") is None or payload.data.get("channel_id") is None:
+ log.trace(f"Ignoring message edit {payload.message_id}: missing content or channel ID.")
+ return
+
+ # Parse the message to see if the code blocks have been fixed.
+ content = payload.data.get("content")
+ instructions = get_instructions(content)
+
+ bot_message = await self.get_sent_instructions(payload)
+ if not bot_message:
+ return
+
+ if not instructions:
+ log.info("User's incorrect code block has been fixed. Removing instructions message.")
+ await bot_message.delete()
+ del self.codeblock_message_ids[payload.message_id]
+ else:
+ log.info("Message edited but still has invalid code blocks; editing the instructions.")
+ await bot_message.edit(embed=self.create_embed(instructions))
diff --git a/bot/exts/info/codeblock/_instructions.py b/bot/exts/info/codeblock/_instructions.py
new file mode 100644
index 000000000..508f157fb
--- /dev/null
+++ b/bot/exts/info/codeblock/_instructions.py
@@ -0,0 +1,184 @@
+"""This module generates and formats instructional messages about fixing Markdown code blocks."""
+
+import logging
+from typing import Optional
+
+from bot.exts.info.codeblock import _parsing
+
+log = logging.getLogger(__name__)
+
+_EXAMPLE_PY = "{lang}\nprint('Hello, world!')" # Make sure to escape any Markdown symbols here.
+_EXAMPLE_CODE_BLOCKS = (
+ "\\`\\`\\`{content}\n\\`\\`\\`\n\n"
+ "**This will result in the following:**\n"
+ "```{content}```"
+)
+
+
+def _get_example(language: str) -> str:
+ """Return an example of a correct code block using `language` for syntax highlighting."""
+ # Determine the example code to put in the code block based on the language specifier.
+ if language.lower() in _parsing.PY_LANG_CODES:
+ log.trace(f"Code block has a Python language specifier `{language}`.")
+ content = _EXAMPLE_PY.format(lang=language)
+ elif language:
+ log.trace(f"Code block has a foreign language specifier `{language}`.")
+ # It's not feasible to determine what would be a valid example for other languages.
+ content = f"{language}\n..."
+ else:
+ log.trace("Code block has no language specifier.")
+ content = "\nHello, world!"
+
+ return _EXAMPLE_CODE_BLOCKS.format(content=content)
+
+
+def _get_bad_ticks_message(code_block: _parsing.CodeBlock) -> Optional[str]:
+ """Return instructions on using the correct ticks for `code_block`."""
+ log.trace("Creating instructions for incorrect code block ticks.")
+
+ valid_ticks = f"\\{_parsing.BACKTICK}" * 3
+ instructions = (
+ "It looks like you are trying to paste code into this channel.\n\n"
+ "You seem to be using the wrong symbols to indicate where the code block should start. "
+ f"The correct symbols would be {valid_ticks}, not `{code_block.tick * 3}`."
+ )
+
+ log.trace("Check if the bad ticks code block also has issues with the language specifier.")
+ addition_msg = _get_bad_lang_message(code_block.content)
+ if not addition_msg and not code_block.language:
+ addition_msg = _get_no_lang_message(code_block.content)
+
+ # Combine the back ticks message with the language specifier message. The latter will
+ # already have an example code block.
+ if addition_msg:
+ log.trace("Language specifier issue found; appending additional instructions.")
+
+ # The first line has double newlines which are not desirable when appending the msg.
+ addition_msg = addition_msg.replace("\n\n", " ", 1)
+
+ # Make the first character of the addition lower case.
+ instructions += "\n\nFurthermore, " + addition_msg[0].lower() + addition_msg[1:]
+ else:
+ log.trace("No issues with the language specifier found.")
+ example_blocks = _get_example(code_block.language)
+ instructions += f"\n\n**Here is an example of how it should look:**\n{example_blocks}"
+
+ return instructions
+
+
+def _get_no_ticks_message(content: str) -> Optional[str]:
+ """If `content` is Python/REPL code, return instructions on using code blocks."""
+ log.trace("Creating instructions for a missing code block.")
+
+ if _parsing.is_python_code(content):
+ example_blocks = _get_example("python")
+ return (
+ "It looks like you're trying to paste code into this channel.\n\n"
+ "Discord has support for Markdown, which allows you to post code with full "
+ "syntax highlighting. Please use these whenever you paste code, as this "
+ "helps improve the legibility and makes it easier for us to help you.\n\n"
+ f"**To do this, use the following method:**\n{example_blocks}"
+ )
+ else:
+ log.trace("Aborting missing code block instructions: content is not Python code.")
+
+
+def _get_bad_lang_message(content: str) -> Optional[str]:
+ """
+ Return instructions on fixing the Python language specifier for a code block.
+
+ If `code_block` does not have a Python language specifier, return None.
+ If there's nothing wrong with the language specifier, return None.
+ """
+ log.trace("Creating instructions for a poorly specified language.")
+
+ info = _parsing.parse_bad_language(content)
+ if not info:
+ log.trace("Aborting bad language instructions: language specified isn't Python.")
+ return
+
+ lines = []
+ language = info.language
+
+ if info.has_leading_spaces:
+ log.trace("Language specifier was preceded by a space.")
+ lines.append(f"Make sure there are no spaces between the back ticks and `{language}`.")
+
+ if not info.has_terminal_newline:
+ log.trace("Language specifier was not followed by a newline.")
+ lines.append(
+ f"Make sure you put your code on a new line following `{language}`. "
+ f"There must not be any spaces after `{language}`."
+ )
+
+ if lines:
+ lines = " ".join(lines)
+ example_blocks = _get_example(language)
+
+ # Note that _get_bad_ticks_message expects the first line to have two newlines.
+ return (
+ f"It looks like you incorrectly specified a language for your code block.\n\n{lines}"
+ f"\n\n**Here is an example of how it should look:**\n{example_blocks}"
+ )
+ else:
+ log.trace("Nothing wrong with the language specifier; no instructions to return.")
+
+
+def _get_no_lang_message(content: str) -> Optional[str]:
+ """
+ Return instructions on specifying a language for a code block.
+
+ If `content` is not valid Python or Python REPL code, return None.
+ """
+ log.trace("Creating instructions for a missing language.")
+
+ if _parsing.is_python_code(content):
+ example_blocks = _get_example("python")
+
+ # Note that _get_bad_ticks_message expects the first line to have two newlines.
+ return (
+ "It looks like you pasted Python code without syntax highlighting.\n\n"
+ "Please use syntax highlighting to improve the legibility of your code and make "
+ "it easier for us to help you.\n\n"
+ f"**To do this, use the following method:**\n{example_blocks}"
+ )
+ else:
+ log.trace("Aborting missing language instructions: content is not Python code.")
+
+
+def get_instructions(content: str) -> Optional[str]:
+ """
+ Parse `content` and return code block formatting instructions if something is wrong.
+
+ Return None if `content` lacks code block formatting issues.
+ """
+ log.trace("Getting formatting instructions.")
+
+ blocks = _parsing.find_code_blocks(content)
+ if blocks is None:
+ log.trace("At least one valid code block found; no instructions to return.")
+ return
+
+ if not blocks:
+ log.trace("No code blocks were found in message.")
+ instructions = _get_no_ticks_message(content)
+ else:
+ log.trace("Searching results for a code block with invalid ticks.")
+ block = next((block for block in blocks if block.tick != _parsing.BACKTICK), None)
+
+ if block:
+ log.trace("A code block exists but has invalid ticks.")
+ instructions = _get_bad_ticks_message(block)
+ else:
+ log.trace("A code block exists but is missing a language.")
+ block = blocks[0]
+
+ # Check for a bad language first to avoid parsing content into an AST.
+ instructions = _get_bad_lang_message(block.content)
+ if not instructions:
+ instructions = _get_no_lang_message(block.content)
+
+ if instructions:
+ instructions += "\nYou can **edit your original message** to correct your code block."
+
+ return instructions
diff --git a/bot/exts/info/codeblock/_parsing.py b/bot/exts/info/codeblock/_parsing.py
new file mode 100644
index 000000000..a98218dfb
--- /dev/null
+++ b/bot/exts/info/codeblock/_parsing.py
@@ -0,0 +1,228 @@
+"""This module provides functions for parsing Markdown code blocks."""
+
+import ast
+import logging
+import re
+import textwrap
+from typing import NamedTuple, Optional, Sequence
+
+from bot import constants
+from bot.utils import has_lines
+
+log = logging.getLogger(__name__)
+
+BACKTICK = "`"
+PY_LANG_CODES = ("python", "pycon", "py") # Order is important; "py" is last cause it's a subset.
+_TICKS = {
+ BACKTICK,
+ "'",
+ '"',
+ "\u00b4", # ACUTE ACCENT
+ "\u2018", # LEFT SINGLE QUOTATION MARK
+ "\u2019", # RIGHT SINGLE QUOTATION MARK
+ "\u2032", # PRIME
+ "\u201c", # LEFT DOUBLE QUOTATION MARK
+ "\u201d", # RIGHT DOUBLE QUOTATION MARK
+ "\u2033", # DOUBLE PRIME
+ "\u3003", # VERTICAL KANA REPEAT MARK UPPER HALF
+}
+
+_RE_PYTHON_REPL = re.compile(r"^(>>>|\.\.\.)( |$)")
+_RE_IPYTHON_REPL = re.compile(r"^((In|Out) \[\d+\]: |\s*\.{3,}: ?)")
+
+_RE_CODE_BLOCK = re.compile(
+ fr"""
+ (?P<ticks>
+ (?P<tick>[{''.join(_TICKS)}]) # Put all ticks into a character class within a group.
+ \2{{2}} # Match previous group 2 more times to ensure the same char.
+ )
+ (?P<lang>[^\W_]+\n)? # Optionally match a language specifier followed by a newline.
+ (?P<code>.+?) # Match the actual code within the block.
+ \1 # Match the same 3 ticks used at the start of the block.
+ """,
+ re.DOTALL | re.VERBOSE
+)
+
+_RE_LANGUAGE = re.compile(
+ fr"""
+ ^(?P<spaces>\s+)? # Optionally match leading spaces from the beginning.
+ (?P<lang>{'|'.join(PY_LANG_CODES)}) # Match a Python language.
+ (?P<newline>\n)? # Optionally match a newline following the language.
+ """,
+ re.IGNORECASE | re.VERBOSE
+)
+
+
+class CodeBlock(NamedTuple):
+ """Represents a Markdown code block."""
+
+ content: str
+ language: str
+ tick: str
+
+
+class BadLanguage(NamedTuple):
+ """Parsed information about a poorly formatted language specifier."""
+
+ language: str
+ has_leading_spaces: bool
+ has_terminal_newline: bool
+
+
+def find_code_blocks(message: str) -> Optional[Sequence[CodeBlock]]:
+ """
+ Find and return all Markdown code blocks in the `message`.
+
+ Code blocks with 3 or fewer lines are excluded.
+
+ If the `message` contains at least one code block with valid ticks and a specified language,
+ return None. This is based on the assumption that if the user managed to get one code block
+ right, they already know how to fix the rest themselves.
+ """
+ log.trace("Finding all code blocks in a message.")
+
+ code_blocks = []
+ for match in _RE_CODE_BLOCK.finditer(message):
+ # Used to ensure non-matched groups have an empty string as the default value.
+ groups = match.groupdict("")
+ language = groups["lang"].strip() # Strip the newline cause it's included in the group.
+
+ if groups["tick"] == BACKTICK and language:
+ log.trace("Message has a valid code block with a language; returning None.")
+ return None
+ elif has_lines(groups["code"], constants.CodeBlock.minimum_lines):
+ code_block = CodeBlock(groups["code"], language, groups["tick"])
+ code_blocks.append(code_block)
+ else:
+ log.trace("Skipped a code block shorter than 4 lines.")
+
+ return code_blocks
+
+
+def _is_python_code(content: str) -> bool:
+ """Return True if `content` is valid Python consisting of more than just expressions."""
+ log.trace("Checking if content is Python code.")
+ try:
+ # Attempt to parse the message into an AST node.
+ # Invalid Python code will raise a SyntaxError.
+ tree = ast.parse(content)
+ except SyntaxError:
+ log.trace("Code is not valid Python.")
+ return False
+
+ # Multiple lines of single words could be interpreted as expressions.
+ # This check is to avoid all nodes being parsed as expressions.
+ # (e.g. words over multiple lines)
+ if not all(isinstance(node, ast.Expr) for node in tree.body):
+ log.trace("Code is valid python.")
+ return True
+ else:
+ log.trace("Code consists only of expressions.")
+ return False
+
+
+def _is_repl_code(content: str, threshold: int = 3) -> bool:
+ """Return True if `content` has at least `threshold` number of (I)Python REPL-like lines."""
+ log.trace(f"Checking if content is (I)Python REPL code using a threshold of {threshold}.")
+
+ repl_lines = 0
+ patterns = (_RE_PYTHON_REPL, _RE_IPYTHON_REPL)
+
+ for line in content.splitlines():
+ # Check the line against all patterns.
+ for pattern in patterns:
+ if pattern.match(line):
+ repl_lines += 1
+
+ # Once a pattern is matched, only use that pattern for the remaining lines.
+ patterns = (pattern,)
+ break
+
+ if repl_lines == threshold:
+ log.trace("Content is (I)Python REPL code.")
+ return True
+
+ log.trace("Content is not (I)Python REPL code.")
+ return False
+
+
+def is_python_code(content: str) -> bool:
+ """Return True if `content` is valid Python code or (I)Python REPL output."""
+ dedented = textwrap.dedent(content)
+
+ # Parse AST twice in case _fix_indentation ends up breaking code due to its inaccuracies.
+ return (
+ _is_python_code(dedented)
+ or _is_repl_code(dedented)
+ or _is_python_code(_fix_indentation(content))
+ )
+
+
+def parse_bad_language(content: str) -> Optional[BadLanguage]:
+ """
+ Return information about a poorly formatted Python language in code block `content`.
+
+ If the language is not Python, return None.
+ """
+ log.trace("Parsing bad language.")
+
+ match = _RE_LANGUAGE.match(content)
+ if not match:
+ return None
+
+ return BadLanguage(
+ language=match["lang"],
+ has_leading_spaces=match["spaces"] is not None,
+ has_terminal_newline=match["newline"] is not None,
+ )
+
+
+def _get_leading_spaces(content: str) -> int:
+ """Return the number of spaces at the start of the first line in `content`."""
+ leading_spaces = 0
+ for char in content:
+ if char == " ":
+ leading_spaces += 1
+ else:
+ return leading_spaces
+
+
+def _fix_indentation(content: str) -> str:
+ """
+ Attempt to fix badly indented code in `content`.
+
+ In most cases, this works like textwrap.dedent. However, if the first line ends with a colon,
+ all subsequent lines are re-indented to only be one level deep relative to the first line.
+ The intent is to fix cases where the leading spaces of the first line of code were accidentally
+ not copied, which makes the first line appear not indented.
+
+ This is fairly naïve and inaccurate. Therefore, it may break some code that was otherwise valid.
+ It's meant to catch really common cases, so that's acceptable. Its flaws are:
+
+ - It assumes that if the first line ends with a colon, it is the start of an indented block
+ - It uses 4 spaces as the indentation, regardless of what the rest of the code uses
+ """
+ lines = content.splitlines(keepends=True)
+
+ # Dedent the first line
+ first_indent = _get_leading_spaces(content)
+ first_line = lines[0][first_indent:]
+
+ # Can't assume there'll be multiple lines cause line counts of edited messages aren't checked.
+ if len(lines) == 1:
+ return first_line
+
+ second_indent = _get_leading_spaces(lines[1])
+
+ # If the first line ends with a colon, all successive lines need to be indented one
+ # additional level (assumes an indent width of 4).
+ if first_line.rstrip().endswith(":"):
+ second_indent -= 4
+
+ # All lines must be dedented at least by the same amount as the first line.
+ first_indent = max(first_indent, second_indent)
+
+ # Dedent the rest of the lines and join them together with the first line.
+ content = first_line + "".join(line[first_indent:] for line in lines[1:])
+
+ return content
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 52239c19e..5aaf85e5a 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -6,15 +6,16 @@ from collections import Counter, defaultdict
from string import Template
from typing import Any, Mapping, Optional, Tuple, Union
-from discord import ChannelType, Colour, CustomActivity, Embed, Guild, Member, Message, Role, Status, utils
+from discord import ChannelType, Colour, Embed, Guild, Message, Role, Status, utils
from discord.abc import GuildChannel
from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group, has_any_role
-from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
+from bot.converters import FetchedMember
from bot.decorators import in_whitelist
from bot.pagination import LinePaginator
+from bot.utils.channel import is_mod_channel
from bot.utils.checks import cooldown_with_role_bypass, has_no_roles_check, in_whitelist_check
from bot.utils.time import time_since
@@ -153,7 +154,9 @@ class Information(Cog):
channel_counts = self.get_channel_type_counts(ctx.guild)
# How many of each user status?
- statuses = Counter(member.status for member in ctx.guild.members)
+ py_invite = await self.bot.fetch_invite(constants.Guild.invite)
+ online_presences = py_invite.approximate_presence_count
+ offline_presences = py_invite.approximate_member_count - online_presences
embed = Embed(colour=Colour.blurple())
# How many staff members and staff channels do we have?
@@ -181,10 +184,8 @@ class Information(Cog):
Roles: {roles}
**Member statuses**
- {constants.Emojis.status_online} {statuses[Status.online]:,}
- {constants.Emojis.status_idle} {statuses[Status.idle]:,}
- {constants.Emojis.status_dnd} {statuses[Status.dnd]:,}
- {constants.Emojis.status_offline} {statuses[Status.offline]:,}
+ {constants.Emojis.status_online} {online_presences:,}
+ {constants.Emojis.status_offline} {offline_presences:,}
""")
).substitute({"channel_counts": channel_counts})
embed.set_thumbnail(url=ctx.guild.icon_url)
@@ -192,7 +193,7 @@ class Information(Cog):
await ctx.send(embed=embed)
@command(name="user", aliases=["user_info", "member", "member_info"])
- async def user_info(self, ctx: Context, user: Member = None) -> None:
+ async def user_info(self, ctx: Context, user: FetchedMember = None) -> None:
"""Returns info about a user."""
if user is None:
user = ctx.author
@@ -207,31 +208,14 @@ class Information(Cog):
embed = await self.create_user_embed(ctx, user)
await ctx.send(embed=embed)
- async def create_user_embed(self, ctx: Context, user: Member) -> Embed:
+ async def create_user_embed(self, ctx: Context, user: FetchedMember) -> Embed:
"""Creates an embed containing information on the `user`."""
- created = time_since(user.created_at, max_units=3)
-
- # Custom status
- custom_status = ''
- for activity in user.activities:
- if isinstance(activity, CustomActivity):
- state = ""
-
- if activity.name:
- state = escape_markdown(activity.name)
-
- emoji = ""
- if activity.emoji:
- # If an emoji is unicode use the emoji, else write the emote like :abc:
- if not activity.emoji.id:
- emoji += activity.emoji.name + " "
- else:
- emoji += f"`:{activity.emoji.name}:` "
+ on_server = bool(ctx.guild.get_member(user.id))
- custom_status = f'Status: {emoji}{state}\n'
+ created = time_since(user.created_at, max_units=3)
name = str(user)
- if user.nick:
+ if on_server and user.nick:
name = f"{user.nick} ({name})"
badges = []
@@ -240,12 +224,16 @@ class Information(Cog):
if is_set and (emoji := getattr(constants.Emojis, f"badge_{badge}", None)):
badges.append(emoji)
- joined = time_since(user.joined_at, max_units=3)
- roles = ", ".join(role.mention for role in user.roles[1:])
-
- desktop_status = STATUS_EMOTES.get(user.desktop_status, constants.Emojis.status_online)
- web_status = STATUS_EMOTES.get(user.web_status, constants.Emojis.status_online)
- mobile_status = STATUS_EMOTES.get(user.mobile_status, constants.Emojis.status_online)
+ if on_server:
+ joined = time_since(user.joined_at, max_units=3)
+ roles = ", ".join(role.mention for role in user.roles[1:])
+ membership = textwrap.dedent(f"""
+ Joined: {joined}
+ Roles: {roles or None}
+ """).strip()
+ else:
+ roles = None
+ membership = "The user is not a member of the server"
fields = [
(
@@ -254,34 +242,16 @@ class Information(Cog):
Created: {created}
Profile: {user.mention}
ID: {user.id}
- {custom_status}
""").strip()
),
(
"Member information",
- textwrap.dedent(f"""
- Joined: {joined}
- Roles: {roles or None}
- """).strip()
+ membership
),
- (
- "Status",
- textwrap.dedent(f"""
- {desktop_status} Desktop
- {web_status} Web
- {mobile_status} Mobile
- """).strip()
- )
]
- # Use getattr to future-proof for commands invoked via DMs.
- show_verbose = (
- ctx.channel.id in constants.MODERATION_CHANNELS
- or getattr(ctx.channel, "category_id", None) == constants.Categories.modmail
- )
-
# Show more verbose output in moderation channels for infractions and nominations
- if show_verbose:
+ if is_mod_channel(ctx.channel):
fields.append(await self.expanded_user_infraction_counts(user))
fields.append(await self.user_nomination_counts(user))
else:
@@ -301,13 +271,13 @@ class Information(Cog):
return embed
- async def basic_user_infraction_counts(self, member: Member) -> Tuple[str, str]:
+ async def basic_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:
"""Gets the total and active infraction counts for the given `member`."""
infractions = await self.bot.api_client.get(
'bot/infractions',
params={
'hidden': 'False',
- 'user__id': str(member.id)
+ 'user__id': str(user.id)
}
)
@@ -318,7 +288,7 @@ class Information(Cog):
return "Infractions", infraction_output
- async def expanded_user_infraction_counts(self, member: Member) -> Tuple[str, str]:
+ async def expanded_user_infraction_counts(self, user: FetchedMember) -> Tuple[str, str]:
"""
Gets expanded infraction counts for the given `member`.
@@ -328,7 +298,7 @@ class Information(Cog):
infractions = await self.bot.api_client.get(
'bot/infractions',
params={
- 'user__id': str(member.id)
+ 'user__id': str(user.id)
}
)
@@ -359,12 +329,12 @@ class Information(Cog):
return "Infractions", "\n".join(infraction_output)
- async def user_nomination_counts(self, member: Member) -> Tuple[str, str]:
+ async def user_nomination_counts(self, user: FetchedMember) -> Tuple[str, str]:
"""Gets the active and historical nomination counts for the given `member`."""
nominations = await self.bot.api_client.get(
'bot/nominations',
params={
- 'user__id': str(member.id)
+ 'user__id': str(user.id)
}
)
diff --git a/bot/exts/info/reddit.py b/bot/exts/info/reddit.py
index debe40c82..bad4c504d 100644
--- a/bot/exts/info/reddit.py
+++ b/bot/exts/info/reddit.py
@@ -140,7 +140,10 @@ class Reddit(Cog):
# Got appropriate response - process and return.
content = await response.json()
posts = content["data"]["children"]
- return posts[:amount]
+
+ filtered_posts = [post for post in posts if not post["data"]["over_18"]]
+
+ return filtered_posts[:amount]
await asyncio.sleep(3)
@@ -163,12 +166,11 @@ class Reddit(Cog):
amount=amount,
params={"t": time}
)
-
if not posts:
embed.title = random.choice(ERROR_REPLIES)
embed.colour = Colour.red()
embed.description = (
- "Sorry! We couldn't find any posts from that subreddit. "
+ "Sorry! We couldn't find any SFW posts from that subreddit. "
"If this problem persists, please let us know."
)
diff --git a/bot/exts/info/site.py b/bot/exts/info/site.py
index 2d3a3d9f3..fb5b99086 100644
--- a/bot/exts/info/site.py
+++ b/bot/exts/info/site.py
@@ -1,7 +1,7 @@
import logging
from discord import Colour, Embed
-from discord.ext.commands import Cog, Context, group
+from discord.ext.commands import Cog, Context, Greedy, group
from bot.bot import Bot
from bot.constants import URLs
@@ -105,10 +105,9 @@ class Site(Cog):
await ctx.send(embed=embed)
@site_group.command(name="rules", aliases=("r", "rule"), root_aliases=("rules", "rule"))
- async def site_rules(self, ctx: Context, *rules: int) -> None:
+ async def site_rules(self, ctx: Context, rules: Greedy[int]) -> None:
"""Provides a link to all rules or, if specified, displays specific rule(s)."""
- rules_embed = Embed(title='Rules', color=Colour.blurple())
- rules_embed.url = f"{PAGES_URL}/rules"
+ rules_embed = Embed(title='Rules', color=Colour.blurple(), url=f'{PAGES_URL}/rules')
if not rules:
# Rules were not submitted. Return the default description.
@@ -122,15 +121,13 @@ class Site(Cog):
return
full_rules = await self.bot.api_client.get('rules', params={'link_format': 'md'})
- invalid_indices = tuple(
- pick
- for pick in rules
- if pick < 1 or pick > len(full_rules)
- )
- if invalid_indices:
- indices = ', '.join(map(str, invalid_indices))
- await ctx.send(f":x: Invalid rule indices: {indices}")
+ # Remove duplicates and sort the rule indices
+ rules = sorted(set(rules))
+ invalid = ', '.join(str(index) for index in rules if index < 1 or index > len(full_rules))
+
+ if invalid:
+ await ctx.send(f":x: Invalid rule indices: {invalid}")
return
for rule in rules:
diff --git a/bot/exts/info/stats.py b/bot/exts/info/stats.py
index d42f55466..4d8bb645e 100644
--- a/bot/exts/info/stats.py
+++ b/bot/exts/info/stats.py
@@ -1,13 +1,12 @@
import string
-from datetime import datetime
-from discord import Member, Message, Status
+from discord import Member, Message
from discord.ext.commands import Cog, Context
from discord.ext.tasks import loop
from bot.bot import Bot
-from bot.constants import Categories, Channels, Guild, Stats as StatConf
-
+from bot.constants import Categories, Channels, Guild
+from bot.utils.channel import is_in_category
CHANNEL_NAME_OVERRIDES = {
Channels.off_topic_0: "off_topic_0",
@@ -36,8 +35,7 @@ class Stats(Cog):
if message.guild.id != Guild.id:
return
- cat = getattr(message.channel, "category", None)
- if cat is not None and cat.id == Categories.modmail:
+ if is_in_category(message.channel, Categories.modmail):
if message.channel.id != Channels.incidents:
# Do not report modmail channels to stats, there are too many
# of them for interesting statistics to be drawn out of this.
@@ -79,38 +77,6 @@ class Stats(Cog):
self.bot.stats.gauge("guild.total_members", len(member.guild.members))
- @Cog.listener()
- async def on_member_update(self, _before: Member, after: Member) -> None:
- """Update presence estimates on member update."""
- if after.guild.id != Guild.id:
- return
-
- if self.last_presence_update:
- if (datetime.now() - self.last_presence_update).seconds < StatConf.presence_update_timeout:
- return
-
- self.last_presence_update = datetime.now()
-
- online = 0
- idle = 0
- dnd = 0
- offline = 0
-
- for member in after.guild.members:
- if member.status is Status.online:
- online += 1
- elif member.status is Status.dnd:
- dnd += 1
- elif member.status is Status.idle:
- idle += 1
- elif member.status is Status.offline:
- offline += 1
-
- self.bot.stats.gauge("guild.status.online", online)
- self.bot.stats.gauge("guild.status.idle", idle)
- self.bot.stats.gauge("guild.status.do_not_disturb", dnd)
- self.bot.stats.gauge("guild.status.offline", offline)
-
@loop(hours=1)
async def update_guild_boost(self) -> None:
"""Post the server boost level and tier every hour."""
diff --git a/bot/exts/moderation/dm_relay.py b/bot/exts/moderation/dm_relay.py
index 14263e004..4d5142b55 100644
--- a/bot/exts/moderation/dm_relay.py
+++ b/bot/exts/moderation/dm_relay.py
@@ -90,7 +90,11 @@ class DMRelay(Cog):
# Handle any attachments
if message.attachments:
try:
- await send_attachments(message, self.webhook)
+ await send_attachments(
+ message,
+ self.webhook,
+ username=f"{message.author.display_name} ({message.author.id})"
+ )
except (discord.errors.Forbidden, discord.errors.NotFound):
e = discord.Embed(
description=":x: **This message contained an attachment, but it could not be retrieved**",
diff --git a/bot/exts/moderation/infraction/_scheduler.py b/bot/exts/moderation/infraction/_scheduler.py
index 814b17830..bebade0ae 100644
--- a/bot/exts/moderation/infraction/_scheduler.py
+++ b/bot/exts/moderation/infraction/_scheduler.py
@@ -12,11 +12,12 @@ from discord.ext.commands import Context
from bot import constants
from bot.api import ResponseCodeError
from bot.bot import Bot
-from bot.constants import Colours, MODERATION_CHANNELS
+from bot.constants import Colours
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._utils import UserSnowflake
from bot.exts.moderation.modlog import ModLog
from bot.utils import messages, scheduling, time
+from bot.utils.channel import is_mod_channel
log = logging.getLogger(__name__)
@@ -125,7 +126,7 @@ class InfractionScheduler:
log.error(f"Failed to DM {user.id}: could not fetch user (status {e.status})")
else:
# Accordingly display whether the user was successfully notified via DM.
- if await _utils.notify_infraction(user, infr_type, expiry, reason, icon):
+ if await _utils.notify_infraction(user, " ".join(infr_type.split("_")).title(), expiry, reason, icon):
dm_result = ":incoming_envelope: "
dm_log_text = "\nDM: Sent"
@@ -136,11 +137,7 @@ class InfractionScheduler:
)
if reason:
end_msg = f" (reason: {textwrap.shorten(reason, width=1500, placeholder='...')})"
- elif ctx.channel.id not in MODERATION_CHANNELS:
- log.trace(
- f"Infraction #{id_} context is not in a mod channel; omitting infraction count."
- )
- else:
+ elif is_mod_channel(ctx.channel):
log.trace(f"Fetching total infraction count for {user}.")
infractions = await self.bot.api_client.get(
@@ -148,7 +145,7 @@ class InfractionScheduler:
params={"user__id": str(user.id)}
)
total = len(infractions)
- end_msg = f" ({total} infraction{ngettext('', 's', total)} total)"
+ end_msg = f" (#{id_} ; {total} infraction{ngettext('', 's', total)} total)"
# Execute the necessary actions to apply the infraction on Discord.
if action_coro:
@@ -166,7 +163,7 @@ class InfractionScheduler:
log_content = ctx.author.mention
log_title = "failed to apply"
- log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}"
+ log_msg = f"Failed to apply {' '.join(infr_type.split('_'))} infraction #{id_} to {user}"
if isinstance(e, discord.Forbidden):
log.warning(f"{log_msg}: bot lacks permissions.")
else:
@@ -183,7 +180,7 @@ class InfractionScheduler:
log.error(f"Deletion of {infr_type} infraction #{id_} failed with error code {e.status}.")
infr_message = ""
else:
- infr_message = f" **{infr_type}** to {user.mention}{expiry_msg}{end_msg}"
+ infr_message = f" **{' '.join(infr_type.split('_'))}** to {user.mention}{expiry_msg}{end_msg}"
# Send a confirmation message to the invoking context.
log.trace(f"Sending infraction #{id_} confirmation message.")
@@ -195,7 +192,7 @@ class InfractionScheduler:
await self.mod_log.send_log_message(
icon_url=icon,
colour=Colours.soft_red,
- title=f"Infraction {log_title}: {infr_type}",
+ title=f"Infraction {log_title}: {' '.join(infr_type.split('_'))}",
thumbnail=user.avatar_url_as(static_format="png"),
text=textwrap.dedent(f"""
Member: {messages.format_user(user)}
@@ -272,7 +269,7 @@ class InfractionScheduler:
if send_msg:
log.trace(f"Sending infraction #{id_} pardon confirmation message.")
await ctx.send(
- f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. "
+ f"{dm_emoji}{confirm_msg} infraction **{' '.join(infr_type.split('_'))}** for {user.mention}. "
f"{log_text.get('Failure', '')}"
)
@@ -283,7 +280,7 @@ class InfractionScheduler:
await self.mod_log.send_log_message(
icon_url=_utils.INFRACTION_ICONS[infr_type][1],
colour=Colours.soft_green,
- title=f"Infraction {log_title}: {infr_type}",
+ title=f"Infraction {log_title}: {' '.join(infr_type.split('_'))}",
thumbnail=user.avatar_url_as(static_format="png"),
text="\n".join(f"{k}: {v}" for k, v in log_text.items()),
footer=footer,
diff --git a/bot/exts/moderation/infraction/_utils.py b/bot/exts/moderation/infraction/_utils.py
index 1d91964f1..d0dc3f0a1 100644
--- a/bot/exts/moderation/infraction/_utils.py
+++ b/bot/exts/moderation/infraction/_utils.py
@@ -18,9 +18,10 @@ INFRACTION_ICONS = {
"note": (Icons.user_warn, None),
"superstar": (Icons.superstarify, Icons.unsuperstarify),
"warning": (Icons.user_warn, None),
+ "voice_ban": (Icons.voice_state_red, Icons.voice_state_green),
}
RULES_URL = "https://pythondiscord.com/pages/rules"
-APPEALABLE_INFRACTIONS = ("ban", "mute")
+APPEALABLE_INFRACTIONS = ("ban", "mute", "voice_ban")
# Type aliases
UserObject = t.Union[discord.Member, discord.User]
@@ -154,7 +155,7 @@ async def notify_infraction(
log.trace(f"Sending {user} a DM about their {infr_type} infraction.")
text = INFRACTION_DESCRIPTION_TEMPLATE.format(
- type=infr_type.capitalize(),
+ type=infr_type.title(),
expires=expires_at or "N/A",
reason=reason or "No reason provided."
)
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index a8b3feb38..746d4e154 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -31,6 +31,7 @@ class Infractions(InfractionScheduler, commands.Cog):
self.category = "Moderation"
self._muted_role = discord.Object(constants.Roles.muted)
+ self._voice_verified_role = discord.Object(constants.Roles.voice_verified)
@commands.Cog.listener()
async def on_member_join(self, member: Member) -> None:
@@ -71,6 +72,28 @@ class Infractions(InfractionScheduler, commands.Cog):
"""Permanently ban a user for the given reason and stop watching them with Big Brother."""
await self.apply_ban(ctx, user, reason)
+ @command(aliases=('pban',))
+ async def purgeban(
+ self,
+ ctx: Context,
+ user: FetchedMember,
+ purge_days: t.Optional[int] = 1,
+ *,
+ reason: t.Optional[str] = None
+ ) -> None:
+ """
+ Same as ban but removes all their messages for the given number of days, default being 1.
+
+ `purge_days` can only be values between 0 and 7.
+ Anything outside these bounds are automatically adjusted to their respective limits.
+ """
+ await self.apply_ban(ctx, user, reason, max(min(purge_days, 7), 0))
+
+ @command(aliases=('vban',))
+ async def voiceban(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str]) -> None:
+ """Permanently ban user from using voice channels."""
+ await self.apply_voice_ban(ctx, user, reason)
+
# endregion
# region: Temporary infractions
@@ -119,6 +142,32 @@ class Infractions(InfractionScheduler, commands.Cog):
"""
await self.apply_ban(ctx, user, reason, expires_at=duration)
+ @command(aliases=("tempvban", "tvban"))
+ async def tempvoiceban(
+ self,
+ ctx: Context,
+ user: FetchedMember,
+ duration: Expiry,
+ *,
+ reason: t.Optional[str]
+ ) -> None:
+ """
+ Temporarily voice ban a user for the given reason and duration.
+
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
+ """
+ await self.apply_voice_ban(ctx, user, reason, expires_at=duration)
+
# endregion
# region: Permanent shadow infractions
@@ -208,6 +257,11 @@ class Infractions(InfractionScheduler, commands.Cog):
"""Prematurely end the active ban infraction for the user."""
await self.pardon_infraction(ctx, "ban", user)
+ @command(aliases=("uvban",))
+ async def unvoiceban(self, ctx: Context, user: FetchedMember) -> None:
+ """Prematurely end the active voice ban infraction for the user."""
+ await self.pardon_infraction(ctx, "voice_ban", user)
+
# endregion
# region: Base apply functions
@@ -246,7 +300,14 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_infraction(ctx, infraction, user, action)
@respect_role_hierarchy(member_arg=2)
- async def apply_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None:
+ async def apply_ban(
+ self,
+ ctx: Context,
+ user: UserSnowflake,
+ reason: t.Optional[str],
+ purge_days: t.Optional[int] = 0,
+ **kwargs
+ ) -> None:
"""
Apply a ban infraction with kwargs passed to `post_infraction`.
@@ -278,7 +339,7 @@ class Infractions(InfractionScheduler, commands.Cog):
if reason:
reason = textwrap.shorten(reason, width=512, placeholder="...")
- action = ctx.guild.ban(user, reason=reason, delete_message_days=0)
+ action = ctx.guild.ban(user, reason=reason, delete_message_days=purge_days)
await self.apply_infraction(ctx, infraction, user, action)
if infraction.get('expires_at') is not None:
@@ -295,6 +356,26 @@ class Infractions(InfractionScheduler, commands.Cog):
bb_reason = "User has been permanently banned from the server. Automatically removed."
await bb_cog.apply_unwatch(ctx, user, bb_reason, send_message=False)
+ @respect_role_hierarchy(member_arg=2)
+ async def apply_voice_ban(self, ctx: Context, user: UserSnowflake, reason: t.Optional[str], **kwargs) -> None:
+ """Apply a voice ban infraction with kwargs passed to `post_infraction`."""
+ if await _utils.get_active_infraction(ctx, user, "voice_ban"):
+ return
+
+ infraction = await _utils.post_infraction(ctx, user, "voice_ban", reason, active=True, **kwargs)
+ if infraction is None:
+ return
+
+ self.mod_log.ignore(Event.member_update, user.id)
+
+ if reason:
+ reason = textwrap.shorten(reason, width=512, placeholder="...")
+
+ await user.move_to(None, reason="Disconnected from voice to apply voiceban.")
+
+ action = user.remove_roles(self._voice_verified_role, reason=reason)
+ await self.apply_infraction(ctx, infraction, user, action)
+
# endregion
# region: Base pardon functions
@@ -339,6 +420,27 @@ class Infractions(InfractionScheduler, commands.Cog):
return log_text
+ async def pardon_voice_ban(self, user_id: int, guild: discord.Guild, reason: t.Optional[str]) -> t.Dict[str, str]:
+ """Add Voice Verified role back to user, DM them a notification, and return a log dict."""
+ user = guild.get_member(user_id)
+ log_text = {}
+
+ if user:
+ # DM user about infraction expiration
+ notified = await _utils.notify_pardon(
+ user=user,
+ title="Voice ban ended",
+ content="You have been unbanned and can verify yourself again in the server.",
+ icon_url=_utils.INFRACTION_ICONS["voice_ban"][1]
+ )
+
+ log_text["Member"] = format_user(user)
+ log_text["DM"] = "Sent" if notified else "**Failed**"
+ else:
+ log_text["Info"] = "User was not found in the guild."
+
+ return log_text
+
async def _pardon_action(self, infraction: _utils.Infraction) -> t.Optional[t.Dict[str, str]]:
"""
Execute deactivation steps specific to the infraction's type and return a log dict.
@@ -353,6 +455,8 @@ class Infractions(InfractionScheduler, commands.Cog):
return await self.pardon_mute(user_id, guild, reason)
elif infraction["type"] == "ban":
return await self.pardon_ban(user_id, guild, reason)
+ elif infraction["type"] == "voice_ban":
+ return await self.pardon_voice_ban(user_id, guild, reason)
# endregion
diff --git a/bot/exts/moderation/infraction/management.py b/bot/exts/moderation/infraction/management.py
index cdab1a6c7..394f63da3 100644
--- a/bot/exts/moderation/infraction/management.py
+++ b/bot/exts/moderation/infraction/management.py
@@ -15,7 +15,7 @@ from bot.exts.moderation.infraction.infractions import Infractions
from bot.exts.moderation.modlog import ModLog
from bot.pagination import LinePaginator
from bot.utils import messages, time
-from bot.utils.checks import in_whitelist_check
+from bot.utils.channel import is_mod_channel
log = logging.getLogger(__name__)
@@ -295,13 +295,7 @@ class ModManagement(commands.Cog):
"""Only allow moderators inside moderator channels to invoke the commands in this cog."""
checks = [
await commands.has_any_role(*constants.MODERATION_ROLES).predicate(ctx),
- in_whitelist_check(
- ctx,
- channels=constants.MODERATION_CHANNELS,
- categories=[constants.Categories.modmail],
- redirect=None,
- fail_silently=True,
- )
+ is_mod_channel(ctx.channel)
]
return all(checks)
diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py
index eec63f5b3..adfe42fcd 100644
--- a/bot/exts/moderation/infraction/superstarify.py
+++ b/bot/exts/moderation/infraction/superstarify.py
@@ -135,7 +135,8 @@ class Superstarify(InfractionScheduler, Cog):
return
# Post the infraction to the API
- reason = reason or f"old nick: {member.display_name}"
+ old_nick = member.display_name
+ reason = reason or f"old nick: {old_nick}"
infraction = await _utils.post_infraction(ctx, member, "superstar", reason, duration, active=True)
id_ = infraction["id"]
@@ -148,7 +149,7 @@ class Superstarify(InfractionScheduler, Cog):
await member.edit(nick=forced_nick, reason=reason)
self.schedule_expiration(infraction)
- old_nick = escape_markdown(member.display_name)
+ old_nick = escape_markdown(old_nick)
forced_nick = escape_markdown(forced_nick)
# Send a DM to the user to notify them of their new infraction.
diff --git a/bot/exts/moderation/silence.py b/bot/exts/moderation/silence.py
index ac0c1c85e..e6712b3b6 100644
--- a/bot/exts/moderation/silence.py
+++ b/bot/exts/moderation/silence.py
@@ -1,8 +1,11 @@
-import asyncio
+import json
import logging
from contextlib import suppress
+from datetime import datetime, timedelta, timezone
+from operator import attrgetter
from typing import Optional
+from async_rediscache import RedisCache
from discord import TextChannel
from discord.ext import commands, tasks
from discord.ext.commands import Context
@@ -10,10 +13,25 @@ from discord.ext.commands import Context
from bot.bot import Bot
from bot.constants import Channels, Emojis, Guild, MODERATION_ROLES, Roles
from bot.converters import HushDurationConverter
+from bot.utils.lock import LockedResourceError, lock_arg
from bot.utils.scheduling import Scheduler
log = logging.getLogger(__name__)
+LOCK_NAMESPACE = "silence"
+
+MSG_SILENCE_FAIL = f"{Emojis.cross_mark} current channel is already silenced."
+MSG_SILENCE_PERMANENT = f"{Emojis.check_mark} silenced current channel indefinitely."
+MSG_SILENCE_SUCCESS = f"{Emojis.check_mark} silenced current channel for {{duration}} minute(s)."
+
+MSG_UNSILENCE_FAIL = f"{Emojis.cross_mark} current channel was not silenced."
+MSG_UNSILENCE_MANUAL = (
+ f"{Emojis.cross_mark} current channel was not unsilenced because the current overwrites were "
+ f"set manually or the cache was prematurely cleared. "
+ f"Please edit the overwrites manually to unsilence."
+)
+MSG_UNSILENCE_SUCCESS = f"{Emojis.check_mark} unsilenced current channel."
+
class SilenceNotifier(tasks.Loop):
"""Loop notifier for posting notices to `alert_channel` containing added channels."""
@@ -56,25 +74,32 @@ class SilenceNotifier(tasks.Loop):
class Silence(commands.Cog):
"""Commands for stopping channel messages for `verified` role in a channel."""
+ # Maps muted channel IDs to their previous overwrites for send_message and add_reactions.
+ # Overwrites are stored as JSON.
+ previous_overwrites = RedisCache()
+
+ # Maps muted channel IDs to POSIX timestamps of when they'll be unsilenced.
+ # A timestamp equal to -1 means it's indefinite.
+ unsilence_timestamps = RedisCache()
+
def __init__(self, bot: Bot):
self.bot = bot
self.scheduler = Scheduler(self.__class__.__name__)
- self.muted_channels = set()
- self._get_instance_vars_task = self.bot.loop.create_task(self._get_instance_vars())
- self._get_instance_vars_event = asyncio.Event()
+ self._init_task = self.bot.loop.create_task(self._async_init())
- async def _get_instance_vars(self) -> None:
- """Get instance variables after they're available to get from the guild."""
+ async def _async_init(self) -> None:
+ """Set instance attributes once the guild is available and reschedule unsilences."""
await self.bot.wait_until_guild_available()
+
guild = self.bot.get_guild(Guild.id)
self._verified_role = guild.get_role(Roles.verified)
self._mod_alerts_channel = self.bot.get_channel(Channels.mod_alerts)
- self._mod_log_channel = self.bot.get_channel(Channels.mod_log)
- self.notifier = SilenceNotifier(self._mod_log_channel)
- self._get_instance_vars_event.set()
+ self.notifier = SilenceNotifier(self.bot.get_channel(Channels.mod_log))
+ await self._reschedule()
@commands.command(aliases=("hush",))
+ @lock_arg(LOCK_NAMESPACE, "ctx", attrgetter("channel"), raise_error=True)
async def silence(self, ctx: Context, duration: HushDurationConverter = 10) -> None:
"""
Silence the current channel for `duration` minutes or `forever`.
@@ -82,18 +107,25 @@ class Silence(commands.Cog):
Duration is capped at 15 minutes, passing forever makes the silence indefinite.
Indefinitely silenced channels get added to a notifier which posts notices every 15 minutes from the start.
"""
- await self._get_instance_vars_event.wait()
- log.debug(f"{ctx.author} is silencing channel #{ctx.channel}.")
- if not await self._silence(ctx.channel, persistent=(duration is None), duration=duration):
- await ctx.send(f"{Emojis.cross_mark} current channel is already silenced.")
- return
- if duration is None:
- await ctx.send(f"{Emojis.check_mark} silenced current channel indefinitely.")
+ await self._init_task
+
+ channel_info = f"#{ctx.channel} ({ctx.channel.id})"
+ log.debug(f"{ctx.author} is silencing channel {channel_info}.")
+
+ if not await self._set_silence_overwrites(ctx.channel):
+ log.info(f"Tried to silence channel {channel_info} but the channel was already silenced.")
+ await ctx.send(MSG_SILENCE_FAIL)
return
- await ctx.send(f"{Emojis.check_mark} silenced current channel for {duration} minute(s).")
+ await self._schedule_unsilence(ctx, duration)
- self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence))
+ if duration is None:
+ self.notifier.add_channel(ctx.channel)
+ log.info(f"Silenced {channel_info} indefinitely.")
+ await ctx.send(MSG_SILENCE_PERMANENT)
+ else:
+ log.info(f"Silenced {channel_info} for {duration} minute(s).")
+ await ctx.send(MSG_SILENCE_SUCCESS.format(duration=duration))
@commands.command(aliases=("unhush",))
async def unsilence(self, ctx: Context) -> None:
@@ -102,61 +134,115 @@ class Silence(commands.Cog):
If the channel was silenced indefinitely, notifications for the channel will stop.
"""
- await self._get_instance_vars_event.wait()
+ await self._init_task
log.debug(f"Unsilencing channel #{ctx.channel} from {ctx.author}'s command.")
- if not await self._unsilence(ctx.channel):
- await ctx.send(f"{Emojis.cross_mark} current channel was not silenced.")
+ await self._unsilence_wrapper(ctx.channel)
+
+ @lock_arg(LOCK_NAMESPACE, "channel", raise_error=True)
+ async def _unsilence_wrapper(self, channel: TextChannel) -> None:
+ """Unsilence `channel` and send a success/failure message."""
+ if not await self._unsilence(channel):
+ overwrite = channel.overwrites_for(self._verified_role)
+ if overwrite.send_messages is False or overwrite.add_reactions is False:
+ await channel.send(MSG_UNSILENCE_MANUAL)
+ else:
+ await channel.send(MSG_UNSILENCE_FAIL)
else:
- await ctx.send(f"{Emojis.check_mark} unsilenced current channel.")
+ await channel.send(MSG_UNSILENCE_SUCCESS)
- async def _silence(self, channel: TextChannel, persistent: bool, duration: Optional[int]) -> bool:
- """
- Silence `channel` for `self._verified_role`.
+ async def _set_silence_overwrites(self, channel: TextChannel) -> bool:
+ """Set silence permission overwrites for `channel` and return True if successful."""
+ overwrite = channel.overwrites_for(self._verified_role)
+ prev_overwrites = dict(send_messages=overwrite.send_messages, add_reactions=overwrite.add_reactions)
- If `persistent` is `True` add `channel` to notifier.
- `duration` is only used for logging; if None is passed `persistent` should be True to not log None.
- Return `True` if channel permissions were changed, `False` otherwise.
- """
- current_overwrite = channel.overwrites_for(self._verified_role)
- if current_overwrite.send_messages is False:
- log.info(f"Tried to silence channel #{channel} ({channel.id}) but the channel was already silenced.")
+ if channel.id in self.scheduler or all(val is False for val in prev_overwrites.values()):
return False
- await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=False))
- self.muted_channels.add(channel)
- if persistent:
- log.info(f"Silenced #{channel} ({channel.id}) indefinitely.")
- self.notifier.add_channel(channel)
- return True
-
- log.info(f"Silenced #{channel} ({channel.id}) for {duration} minute(s).")
+
+ overwrite.update(send_messages=False, add_reactions=False)
+ await channel.set_permissions(self._verified_role, overwrite=overwrite)
+ await self.previous_overwrites.set(channel.id, json.dumps(prev_overwrites))
+
return True
+ async def _schedule_unsilence(self, ctx: Context, duration: Optional[int]) -> None:
+ """Schedule `ctx.channel` to be unsilenced if `duration` is not None."""
+ if duration is None:
+ await self.unsilence_timestamps.set(ctx.channel.id, -1)
+ else:
+ self.scheduler.schedule_later(duration * 60, ctx.channel.id, ctx.invoke(self.unsilence))
+ unsilence_time = datetime.now(tz=timezone.utc) + timedelta(minutes=duration)
+ await self.unsilence_timestamps.set(ctx.channel.id, unsilence_time.timestamp())
+
async def _unsilence(self, channel: TextChannel) -> bool:
"""
Unsilence `channel`.
- Check if `channel` is silenced through a `PermissionOverwrite`,
- if it is unsilence it and remove it from the notifier.
+ If `channel` has a silence task scheduled or has its previous overwrites cached, unsilence
+ it, cancel the task, and remove it from the notifier. Notify admins if it has a task but
+ not cached overwrites.
+
Return `True` if channel permissions were changed, `False` otherwise.
"""
- current_overwrite = channel.overwrites_for(self._verified_role)
- if current_overwrite.send_messages is False:
- await channel.set_permissions(self._verified_role, **dict(current_overwrite, send_messages=None))
- log.info(f"Unsilenced channel #{channel} ({channel.id}).")
- self.scheduler.cancel(channel.id)
- self.notifier.remove_channel(channel)
- self.muted_channels.discard(channel)
- return True
- log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.")
- return False
+ prev_overwrites = await self.previous_overwrites.get(channel.id)
+ if channel.id not in self.scheduler and prev_overwrites is None:
+ log.info(f"Tried to unsilence channel #{channel} ({channel.id}) but the channel was not silenced.")
+ return False
+
+ overwrite = channel.overwrites_for(self._verified_role)
+ if prev_overwrites is None:
+ log.info(f"Missing previous overwrites for #{channel} ({channel.id}); defaulting to None.")
+ overwrite.update(send_messages=None, add_reactions=None)
+ else:
+ overwrite.update(**json.loads(prev_overwrites))
+
+ await channel.set_permissions(self._verified_role, overwrite=overwrite)
+ log.info(f"Unsilenced channel #{channel} ({channel.id}).")
+
+ self.scheduler.cancel(channel.id)
+ self.notifier.remove_channel(channel)
+ await self.previous_overwrites.delete(channel.id)
+ await self.unsilence_timestamps.delete(channel.id)
+
+ if prev_overwrites is None:
+ await self._mod_alerts_channel.send(
+ f"<@&{Roles.admins}> Restored overwrites with default values after unsilencing "
+ f"{channel.mention}. Please check that the `Send Messages` and `Add Reactions` "
+ f"overwrites for {self._verified_role.mention} are at their desired values."
+ )
+
+ return True
+
+ async def _reschedule(self) -> None:
+ """Reschedule unsilencing of active silences and add permanent ones to the notifier."""
+ for channel_id, timestamp in await self.unsilence_timestamps.items():
+ channel = self.bot.get_channel(channel_id)
+ if channel is None:
+ log.info(f"Can't reschedule silence for {channel_id}: channel not found.")
+ continue
+
+ if timestamp == -1:
+ log.info(f"Adding permanent silence for #{channel} ({channel.id}) to the notifier.")
+ self.notifier.add_channel(channel)
+ continue
+
+ dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+ delta = (dt - datetime.now(tz=timezone.utc)).total_seconds()
+ if delta <= 0:
+ # Suppress the error since it's not being invoked by a user via the command.
+ with suppress(LockedResourceError):
+ await self._unsilence_wrapper(channel)
+ else:
+ log.info(f"Rescheduling silence for #{channel} ({channel.id}).")
+ self.scheduler.schedule_later(delta, channel_id, self._unsilence_wrapper(channel))
def cog_unload(self) -> None:
- """Send alert with silenced channels and cancel scheduled tasks on unload."""
- self.scheduler.cancel_all()
- if self.muted_channels:
- channels_string = ''.join(channel.mention for channel in self.muted_channels)
- message = f"<@&{Roles.moderators}> channels left silenced on cog unload: {channels_string}"
- asyncio.create_task(self._mod_alerts_channel.send(message))
+ """Cancel the init task and scheduled tasks."""
+ # It's important to wait for _init_task (specifically for _reschedule) to be cancelled
+ # before cancelling scheduled tasks. Otherwise, it's possible for _reschedule to schedule
+ # more tasks after cancel_all has finished, despite _init_task.cancel being called first.
+ # This is cause cancel() on its own doesn't block until the task is cancelled.
+ self._init_task.cancel()
+ self._init_task.add_done_callback(lambda _: self.scheduler.cancel_all())
# This cannot be static (must have a __func__ attribute).
async def cog_check(self, ctx: Context) -> bool:
diff --git a/bot/exts/moderation/verification.py b/bot/exts/moderation/verification.py
index 206556483..c599156d0 100644
--- a/bot/exts/moderation/verification.py
+++ b/bot/exts/moderation/verification.py
@@ -11,6 +11,7 @@ from discord.ext.commands import Cog, Context, command, group, has_any_role
from discord.utils import snowflake_time
from bot import constants
+from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.decorators import has_no_roles, in_whitelist
from bot.exts.moderation.modlog import ModLog
@@ -53,6 +54,23 @@ If you'd like to unsubscribe from the announcement notifications, simply send `!
<#{constants.Channels.bot_commands}>.
"""
+ALTERNATE_VERIFIED_MESSAGE = f"""
+Thanks for accepting our rules!
+
+You can find a copy of our rules for reference at <https://pythondiscord.com/pages/rules>.
+
+Additionally, if you'd like to receive notifications for the announcements \
+we post in <#{constants.Channels.announcements}>
+from time to time, you can send `!subscribe` to <#{constants.Channels.bot_commands}> at any time \
+to assign yourself the **Announcements** role. We'll mention this role every time we make an announcement.
+
+If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \
+<#{constants.Channels.bot_commands}>.
+
+To introduce you to our community, we've made the following video:
+https://youtu.be/ZH26PuX3re0
+"""
+
# Sent via DMs to users kicked for failing to verify
KICKED_MESSAGE = f"""
Hi! You have been automatically kicked from Python Discord as you have failed to accept our rules \
@@ -156,6 +174,9 @@ class Verification(Cog):
# ]
task_cache = RedisCache()
+ # Create a cache for storing recipients of the alternate welcome DM.
+ member_gating_cache = RedisCache()
+
def __init__(self, bot: Bot) -> None:
"""Start internal tasks."""
self.bot = bot
@@ -335,6 +356,28 @@ class Verification(Cog):
return n_success
+ async def _add_kick_note(self, member: discord.Member) -> None:
+ """
+ Post a note regarding `member` being kicked to site.
+
+ Allows keeping track of kicked members for auditing purposes.
+ """
+ payload = {
+ "active": False,
+ "actor": self.bot.user.id, # Bot actions this autonomously
+ "expires_at": None,
+ "hidden": True,
+ "reason": "Verification kick",
+ "type": "note",
+ "user": member.id,
+ }
+
+ log.trace(f"Posting kick note for member {member} ({member.id})")
+ try:
+ await self.bot.api_client.post("bot/infractions", json=payload)
+ except ResponseCodeError as api_exc:
+ log.warning("Failed to post kick note", exc_info=api_exc)
+
async def _kick_members(self, members: t.Collection[discord.Member]) -> int:
"""
Kick `members` from the PyDis guild.
@@ -353,6 +396,7 @@ class Verification(Cog):
except discord.HTTPException as suspicious_exception:
raise StopExecution(reason=suspicious_exception)
await member.kick(reason=f"User has not verified in {constants.Verification.kicked_after} days")
+ await self._add_kick_note(member)
n_kicked = await self._send_requests(members, kick_request, Limit(batch_size=2, sleep_secs=1))
self.bot.stats.incr("verification.kicked", count=n_kicked)
@@ -519,6 +563,26 @@ class Verification(Cog):
if member.guild.id != constants.Guild.id:
return # Only listen for PyDis events
+ raw_member = await self.bot.http.get_member(member.guild.id, member.id)
+
+ # If the user has the is_pending flag set, they will be using the alternate
+ # gate and will not need a welcome DM with verification instructions.
+ # We will send them an alternate DM once they verify with the welcome
+ # video.
+ if raw_member.get("is_pending"):
+ await self.member_gating_cache.set(member.id, True)
+
+ # TODO: Temporary, remove soon after asking joe.
+ await self.mod_log.send_log_message(
+ icon_url=self.bot.user.avatar_url,
+ colour=discord.Colour.blurple(),
+ title="New native gated user",
+ channel_id=constants.Channels.user_log,
+ text=f"<@{member.id}> ({member.id})",
+ )
+
+ return
+
log.trace(f"Sending on join message to new member: {member.id}")
try:
await safe_dm(member.send(ON_JOIN_MESSAGE))
@@ -526,6 +590,23 @@ class Verification(Cog):
log.exception("DM dispatch failed on unexpected error code")
@Cog.listener()
+ async def on_member_update(self, before: discord.Member, after: discord.Member) -> None:
+ """Check if we need to send a verification DM to a gated user."""
+ before_roles = [role.id for role in before.roles]
+ after_roles = [role.id for role in after.roles]
+
+ if constants.Roles.verified not in before_roles and constants.Roles.verified in after_roles:
+ if await self.member_gating_cache.pop(after.id):
+ try:
+ # If the member has not received a DM from our !accept command
+ # and has gone through the alternate gating system we should send
+ # our alternate welcome DM which includes info such as our welcome
+ # video.
+ await safe_dm(after.send(ALTERNATE_VERIFIED_MESSAGE))
+ except discord.HTTPException:
+ log.exception("DM dispatch failed on unexpected error code")
+
+ @Cog.listener()
async def on_message(self, message: discord.Message) -> None:
"""Check new message event for messages to the checkpoint channel & process."""
if message.channel.id != constants.Channels.verification:
diff --git a/bot/exts/moderation/voice_gate.py b/bot/exts/moderation/voice_gate.py
new file mode 100644
index 000000000..c2743e136
--- /dev/null
+++ b/bot/exts/moderation/voice_gate.py
@@ -0,0 +1,168 @@
+import asyncio
+import logging
+from contextlib import suppress
+from datetime import datetime, timedelta
+
+import discord
+from dateutil import parser
+from discord import Colour
+from discord.ext.commands import Cog, Context, command
+
+from bot.api import ResponseCodeError
+from bot.bot import Bot
+from bot.constants import Channels, Event, MODERATION_ROLES, Roles, VoiceGate as GateConf
+from bot.decorators import has_no_roles, in_whitelist
+from bot.exts.moderation.modlog import ModLog
+from bot.utils.checks import InWhitelistCheckFailure
+
+log = logging.getLogger(__name__)
+
+FAILED_MESSAGE = (
+ """You are not currently eligible to use voice inside Python Discord for the following reasons:\n\n{reasons}"""
+)
+
+MESSAGE_FIELD_MAP = {
+ "verified_at": f"have been verified for less than {GateConf.minimum_days_verified} days",
+ "voice_banned": "have an active voice ban infraction",
+ "total_messages": f"have sent less than {GateConf.minimum_messages} messages",
+}
+
+
+class VoiceGate(Cog):
+ """Voice channels verification management."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+
+ @property
+ def mod_log(self) -> ModLog:
+ """Get the currently loaded ModLog cog instance."""
+ return self.bot.get_cog("ModLog")
+
+ @command(aliases=('voiceverify',))
+ @has_no_roles(Roles.voice_verified)
+ @in_whitelist(channels=(Channels.voice_gate,), redirect=None)
+ async def voice_verify(self, ctx: Context, *_) -> None:
+ """
+ Apply to be able to use voice within the Discord server.
+
+ In order to use voice you must meet all three of the following criteria:
+ - You must have over a certain number of messages within the Discord server
+ - You must have accepted our rules over a certain number of days ago
+ - You must not be actively banned from using our voice channels
+ """
+ try:
+ data = await self.bot.api_client.get(f"bot/users/{ctx.author.id}/metricity_data")
+ except ResponseCodeError as e:
+ if e.status == 404:
+ embed = discord.Embed(
+ title="Not found",
+ description=(
+ "We were unable to find user data for you. "
+ "Please try again shortly, "
+ "if this problem persists please contact the server staff through Modmail.",
+ ),
+ color=Colour.red()
+ )
+ log.info(f"Unable to find Metricity data about {ctx.author} ({ctx.author.id})")
+ else:
+ embed = discord.Embed(
+ title="Unexpected response",
+ description=(
+ "We encountered an error while attempting to find data for your user. "
+ "Please try again and let us know if the problem persists."
+ ),
+ color=Colour.red()
+ )
+ log.warning(f"Got response code {e.status} while trying to get {ctx.author.id} Metricity data.")
+
+ await ctx.author.send(embed=embed)
+ return
+
+ # Pre-parse this for better code style
+ if data["verified_at"] is not None:
+ data["verified_at"] = parser.isoparse(data["verified_at"])
+ else:
+ data["verified_at"] = datetime.utcnow() - timedelta(days=3)
+
+ checks = {
+ "verified_at": data["verified_at"] > datetime.utcnow() - timedelta(days=GateConf.minimum_days_verified),
+ "total_messages": data["total_messages"] < GateConf.minimum_messages,
+ "voice_banned": data["voice_banned"]
+ }
+ failed = any(checks.values())
+ failed_reasons = [MESSAGE_FIELD_MAP[key] for key, value in checks.items() if value is True]
+ [self.bot.stats.incr(f"voice_gate.failed.{key}") for key, value in checks.items() if value is True]
+
+ if failed:
+ embed = discord.Embed(
+ title="Voice Gate failed",
+ description=FAILED_MESSAGE.format(reasons="\n".join(f'• You {reason}.' for reason in failed_reasons)),
+ color=Colour.red()
+ )
+ try:
+ await ctx.author.send(embed=embed)
+ await ctx.send(f"{ctx.author}, please check your DMs.")
+ except discord.Forbidden:
+ await ctx.channel.send(ctx.author.mention, embed=embed)
+ return
+
+ self.mod_log.ignore(Event.member_update, ctx.author.id)
+ embed = discord.Embed(
+ title="Voice gate passed",
+ description="You have been granted permission to use voice channels in Python Discord.",
+ color=Colour.green()
+ )
+
+ if ctx.author.voice:
+ embed.description += "\n\nPlease reconnect to your voice channel to be granted your new permissions."
+
+ try:
+ await ctx.author.send(embed=embed)
+ await ctx.send(f"{ctx.author}, please check your DMs.")
+ except discord.Forbidden:
+ await ctx.channel.send(ctx.author.mention, embed=embed)
+
+ # wait a little bit so those who don't get DMs see the response in-channel before losing perms to see it.
+ await asyncio.sleep(3)
+ await ctx.author.add_roles(discord.Object(Roles.voice_verified), reason="Voice Gate passed")
+
+ self.bot.stats.incr("voice_gate.passed")
+
+ @Cog.listener()
+ async def on_message(self, message: discord.Message) -> None:
+ """Delete all non-staff messages from voice gate channel that don't invoke voice verify command."""
+ # Check is channel voice gate
+ if message.channel.id != Channels.voice_gate:
+ return
+
+ ctx = await self.bot.get_context(message)
+ is_verify_command = ctx.command is not None and ctx.command.name == "voice_verify"
+
+ # When it's bot sent message, delete it after some time
+ if message.author.bot:
+ with suppress(discord.NotFound):
+ await message.delete(delay=GateConf.bot_message_delete_delay)
+ return
+
+ # Then check is member moderator+, because we don't want to delete their messages.
+ if any(role.id in MODERATION_ROLES for role in message.author.roles) and is_verify_command is False:
+ log.trace(f"Excluding moderator message {message.id} from deletion in #{message.channel}.")
+ return
+
+ # Ignore deleted voice verification messages
+ if ctx.command is not None and ctx.command.name == "voice_verify":
+ self.mod_log.ignore(Event.message_delete, message.id)
+
+ with suppress(discord.NotFound):
+ await message.delete()
+
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Check for & ignore any InWhitelistCheckFailure."""
+ if isinstance(error, InWhitelistCheckFailure):
+ error.handled = True
+
+
+def setup(bot: Bot) -> None:
+ """Loads the VoiceGate cog."""
+ bot.add_cog(VoiceGate(bot))
diff --git a/bot/exts/utils/bot.py b/bot/exts/utils/bot.py
index ba1fd2a5c..69d623581 100644
--- a/bot/exts/utils/bot.py
+++ b/bot/exts/utils/bot.py
@@ -1,22 +1,14 @@
-import ast
import logging
-import re
-import time
-from typing import Optional, Tuple
+from typing import Optional
-from discord import Embed, Message, RawMessageUpdateEvent, TextChannel
+from discord import Embed, TextChannel
from discord.ext.commands import Cog, Context, command, group, has_any_role
from bot.bot import Bot
-from bot.constants import Categories, Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs
-from bot.exts.filters.token_remover import TokenRemover
-from bot.exts.filters.webhook_remover import WEBHOOK_URL_RE
-from bot.utils.messages import wait_for_deletion
+from bot.constants import Guild, MODERATION_ROLES, Roles, URLs
log = logging.getLogger(__name__)
-RE_MARKDOWN = re.compile(r'([*_~`|>])')
-
class BotCog(Cog, name="Bot"):
"""Bot information commands."""
@@ -24,19 +16,6 @@ class BotCog(Cog, name="Bot"):
def __init__(self, bot: Bot):
self.bot = bot
- # Stores allowed channels plus epoch time since last call.
- self.channel_cooldowns = {
- Channels.python_discussion: 0,
- }
-
- # These channels will also work, but will not be subject to cooldown
- self.channel_whitelist = (
- Channels.bot_commands,
- )
-
- # Stores improperly formatted Python codeblock message ids and the corresponding bot message
- self.codeblock_message_ids = {}
-
@group(invoke_without_command=True, name="bot", hidden=True)
@has_any_role(Roles.verified)
async def botinfo_group(self, ctx: Context) -> None:
@@ -81,305 +60,6 @@ class BotCog(Cog, name="Bot"):
else:
await channel.send(embed=embed)
- def codeblock_stripping(self, msg: str, bad_ticks: bool) -> Optional[Tuple[Tuple[str, ...], str]]:
- """
- Strip msg in order to find Python code.
-
- Tries to strip out Python code out of msg and returns the stripped block or
- None if the block is a valid Python codeblock.
- """
- if msg.count("\n") >= 3:
- # Filtering valid Python codeblocks and exiting if a valid Python codeblock is found.
- if re.search("```(?:py|python)\n(.*?)```", msg, re.IGNORECASE | re.DOTALL) and not bad_ticks:
- log.trace(
- "Someone wrote a message that was already a "
- "valid Python syntax highlighted code block. No action taken."
- )
- return None
-
- else:
- # Stripping backticks from every line of the message.
- log.trace(f"Stripping backticks from message.\n\n{msg}\n\n")
- content = ""
- for line in msg.splitlines(keepends=True):
- content += line.strip("`")
-
- content = content.strip()
-
- # Remove "Python" or "Py" from start of the message if it exists.
- log.trace(f"Removing 'py' or 'python' from message.\n\n{content}\n\n")
- pycode = False
- if content.lower().startswith("python"):
- content = content[6:]
- pycode = True
- elif content.lower().startswith("py"):
- content = content[2:]
- pycode = True
-
- if pycode:
- content = content.splitlines(keepends=True)
-
- # Check if there might be code in the first line, and preserve it.
- first_line = content[0]
- if " " in content[0]:
- first_space = first_line.index(" ")
- content[0] = first_line[first_space:]
- content = "".join(content)
-
- # If there's no code we can just get rid of the first line.
- else:
- content = "".join(content[1:])
-
- # Strip it again to remove any leading whitespace. This is necessary
- # if the first line of the message looked like ```python <code>
- old = content.strip()
-
- # Strips REPL code out of the message if there is any.
- content, repl_code = self.repl_stripping(old)
- if old != content:
- return (content, old), repl_code
-
- # Try to apply indentation fixes to the code.
- content = self.fix_indentation(content)
-
- # Check if the code contains backticks, if it does ignore the message.
- if "`" in content:
- log.trace("Detected ` inside the code, won't reply")
- return None
- else:
- log.trace(f"Returning message.\n\n{content}\n\n")
- return (content,), repl_code
-
- def fix_indentation(self, msg: str) -> str:
- """Attempts to fix badly indented code."""
- def unindent(code: str, skip_spaces: int = 0) -> str:
- """Unindents all code down to the number of spaces given in skip_spaces."""
- final = ""
- current = code[0]
- leading_spaces = 0
-
- # Get numbers of spaces before code in the first line.
- while current == " ":
- current = code[leading_spaces + 1]
- leading_spaces += 1
- leading_spaces -= skip_spaces
-
- # If there are any, remove that number of spaces from every line.
- if leading_spaces > 0:
- for line in code.splitlines(keepends=True):
- line = line[leading_spaces:]
- final += line
- return final
- else:
- return code
-
- # Apply fix for "all lines are overindented" case.
- msg = unindent(msg)
-
- # If the first line does not end with a colon, we can be
- # certain the next line will be on the same indentation level.
- #
- # If it does end with a colon, we will need to indent all successive
- # lines one additional level.
- first_line = msg.splitlines()[0]
- code = "".join(msg.splitlines(keepends=True)[1:])
- if not first_line.endswith(":"):
- msg = f"{first_line}\n{unindent(code)}"
- else:
- msg = f"{first_line}\n{unindent(code, 4)}"
- return msg
-
- def repl_stripping(self, msg: str) -> Tuple[str, bool]:
- """
- Strip msg in order to extract Python code out of REPL output.
-
- Tries to strip out REPL Python code out of msg and returns the stripped msg.
-
- Returns True for the boolean if REPL code was found in the input msg.
- """
- final = ""
- for line in msg.splitlines(keepends=True):
- if line.startswith(">>>") or line.startswith("..."):
- final += line[4:]
- log.trace(f"Formatted: \n\n{msg}\n\n to \n\n{final}\n\n")
- if not final:
- log.trace(f"Found no REPL code in \n\n{msg}\n\n")
- return msg, False
- else:
- log.trace(f"Found REPL code in \n\n{msg}\n\n")
- return final.rstrip(), True
-
- def has_bad_ticks(self, msg: Message) -> bool:
- """Check to see if msg contains ticks that aren't '`'."""
- not_backticks = [
- "'''", '"""', "\u00b4\u00b4\u00b4", "\u2018\u2018\u2018", "\u2019\u2019\u2019",
- "\u2032\u2032\u2032", "\u201c\u201c\u201c", "\u201d\u201d\u201d", "\u2033\u2033\u2033",
- "\u3003\u3003\u3003"
- ]
-
- return msg.content[:3] in not_backticks
-
- @Cog.listener()
- async def on_message(self, msg: Message) -> None:
- """
- Detect poorly formatted Python code in new messages.
-
- If poorly formatted code is detected, send the user a helpful message explaining how to do
- properly formatted Python syntax highlighting codeblocks.
- """
- is_help_channel = (
- getattr(msg.channel, "category", None)
- and msg.channel.category.id in (Categories.help_available, Categories.help_in_use)
- )
- parse_codeblock = (
- (
- is_help_channel
- or msg.channel.id in self.channel_cooldowns
- or msg.channel.id in self.channel_whitelist
- )
- and not msg.author.bot
- and len(msg.content.splitlines()) > 3
- and not TokenRemover.find_token_in_message(msg)
- and not WEBHOOK_URL_RE.search(msg.content)
- )
-
- if parse_codeblock: # no token in the msg
- on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300
- if not on_cooldown or DEBUG_MODE:
- try:
- if self.has_bad_ticks(msg):
- ticks = msg.content[:3]
- content = self.codeblock_stripping(f"```{msg.content[3:-3]}```", True)
- if content is None:
- return
-
- content, repl_code = content
-
- if len(content) == 2:
- content = content[1]
- else:
- content = content[0]
-
- space_left = 204
- if len(content) >= space_left:
- current_length = 0
- lines_walked = 0
- for line in content.splitlines(keepends=True):
- if current_length + len(line) > space_left or lines_walked == 10:
- break
- current_length += len(line)
- lines_walked += 1
- content = content[:current_length] + "#..."
- content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content)
- howto = (
- "It looks like you are trying to paste code into this channel.\n\n"
- "You seem to be using the wrong symbols to indicate where the codeblock should start. "
- f"The correct symbols would be \\`\\`\\`, not `{ticks}`.\n\n"
- "**Here is an example of how it should look:**\n"
- f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n"
- "**This will result in the following:**\n"
- f"```python\n{content}\n```"
- )
-
- else:
- howto = ""
- content = self.codeblock_stripping(msg.content, False)
- if content is None:
- return
-
- content, repl_code = content
- # Attempts to parse the message into an AST node.
- # Invalid Python code will raise a SyntaxError.
- tree = ast.parse(content[0])
-
- # Multiple lines of single words could be interpreted as expressions.
- # This check is to avoid all nodes being parsed as expressions.
- # (e.g. words over multiple lines)
- if not all(isinstance(node, ast.Expr) for node in tree.body) or repl_code:
- # Shorten the code to 10 lines and/or 204 characters.
- space_left = 204
- if content and repl_code:
- content = content[1]
- else:
- content = content[0]
-
- if len(content) >= space_left:
- current_length = 0
- lines_walked = 0
- for line in content.splitlines(keepends=True):
- if current_length + len(line) > space_left or lines_walked == 10:
- break
- current_length += len(line)
- lines_walked += 1
- content = content[:current_length] + "#..."
-
- content_escaped_markdown = RE_MARKDOWN.sub(r'\\\1', content)
- howto += (
- "It looks like you're trying to paste code into this channel.\n\n"
- "Discord has support for Markdown, which allows you to post code with full "
- "syntax highlighting. Please use these whenever you paste code, as this "
- "helps improve the legibility and makes it easier for us to help you.\n\n"
- f"**To do this, use the following method:**\n"
- f"\\`\\`\\`python\n{content_escaped_markdown}\n\\`\\`\\`\n\n"
- "**This will result in the following:**\n"
- f"```python\n{content}\n```"
- )
-
- log.debug(f"{msg.author} posted something that needed to be put inside python code "
- "blocks. Sending the user some instructions.")
- else:
- log.trace("The code consists only of expressions, not sending instructions")
-
- if howto != "":
- # Increase amount of codeblock correction in stats
- self.bot.stats.incr("codeblock_corrections")
- howto_embed = Embed(description=howto)
- bot_message = await msg.channel.send(f"Hey {msg.author.mention}!", embed=howto_embed)
- self.codeblock_message_ids[msg.id] = bot_message.id
-
- self.bot.loop.create_task(
- wait_for_deletion(bot_message, (msg.author.id,), self.bot)
- )
- else:
- return
-
- if msg.channel.id not in self.channel_whitelist:
- self.channel_cooldowns[msg.channel.id] = time.time()
-
- except SyntaxError:
- log.trace(
- f"{msg.author} posted in a help channel, and when we tried to parse it as Python code, "
- "ast.parse raised a SyntaxError. This probably just means it wasn't Python code. "
- f"The message that was posted was:\n\n{msg.content}\n\n"
- )
-
- @Cog.listener()
- async def on_raw_message_edit(self, payload: RawMessageUpdateEvent) -> None:
- """Check to see if an edited message (previously called out) still contains poorly formatted code."""
- if (
- # Checks to see if the message was called out by the bot
- payload.message_id not in self.codeblock_message_ids
- # Makes sure that there is content in the message
- or payload.data.get("content") is None
- # Makes sure there's a channel id in the message payload
- or payload.data.get("channel_id") is None
- ):
- return
-
- # Retrieve channel and message objects for use later
- channel = self.bot.get_channel(int(payload.data.get("channel_id")))
- user_message = await channel.fetch_message(payload.message_id)
-
- # Checks to see if the user has corrected their codeblock. If it's fixed, has_fixed_codeblock will be None
- has_fixed_codeblock = self.codeblock_stripping(payload.data.get("content"), self.has_bad_ticks(user_message))
-
- # If the message is fixed, delete the bot message and the entry from the id dictionary
- if has_fixed_codeblock is None:
- bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id])
- await bot_message.delete()
- del self.codeblock_message_ids[payload.message_id]
- log.trace("User's incorrect code block has been fixed. Removing bot formatting message.")
-
def setup(bot: Bot) -> None:
"""Load the Bot cog."""
diff --git a/bot/exts/utils/ping.py b/bot/exts/utils/ping.py
index a9ca3dbeb..572fc934b 100644
--- a/bot/exts/utils/ping.py
+++ b/bot/exts/utils/ping.py
@@ -33,7 +33,7 @@ class Latency(commands.Cog):
"""
# datetime.datetime objects do not have the "milliseconds" attribute.
# It must be converted to seconds before converting to milliseconds.
- bot_ping = (datetime.utcnow() - ctx.message.created_at).total_seconds() / 1000
+ bot_ping = (datetime.utcnow() - ctx.message.created_at).total_seconds() * 1000
bot_ping = f"{bot_ping:.{ROUND_LATENCY}f} ms"
try:
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index bf4e24661..3113a1149 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -23,7 +23,7 @@ from bot.utils.time import humanize_delta
log = logging.getLogger(__name__)
-NAMESPACE = "reminder" # Used for the mutually_exclusive decorator; constant to prevent typos
+LOCK_NAMESPACE = "reminder"
WHITELISTED_CHANNELS = Guild.reminder_whitelist
MAXIMUM_REMINDERS = 5
@@ -170,7 +170,7 @@ class Reminders(Cog):
log.trace(f"Scheduling new task #{reminder['id']}")
self.schedule_reminder(reminder)
- @lock_arg(NAMESPACE, "reminder", itemgetter("id"), raise_error=True)
+ @lock_arg(LOCK_NAMESPACE, "reminder", itemgetter("id"), raise_error=True)
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
is_valid, user, channel = self.ensure_valid_reminder(reminder)
@@ -378,7 +378,7 @@ class Reminders(Cog):
mention_ids = [mention.id for mention in mentions]
await self.edit_reminder(ctx, id_, {"mentions": mention_ids})
- @lock_arg(NAMESPACE, "id_", raise_error=True)
+ @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)
async def edit_reminder(self, ctx: Context, id_: int, payload: dict) -> None:
"""Edits a reminder with the given payload, then sends a confirmation message."""
if not await self._can_modify(ctx, id_):
@@ -398,7 +398,7 @@ class Reminders(Cog):
await self._reschedule_reminder(reminder)
@remind_group.command("delete", aliases=("remove", "cancel"))
- @lock_arg(NAMESPACE, "id_", raise_error=True)
+ @lock_arg(LOCK_NAMESPACE, "id_", raise_error=True)
async def delete_reminder(self, ctx: Context, id_: int) -> None:
"""Delete one of your active reminders."""
if not await self._can_modify(ctx, id_):
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index da3e07f42..41cb00541 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -36,11 +36,11 @@ RAW_CODE_REGEX = re.compile(
re.DOTALL # "." also matches newlines
)
-MAX_PASTE_LEN = 1000
+MAX_PASTE_LEN = 10000
# `!eval` command whitelists
-EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric, Channels.code_help_voice)
-EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use)
+EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric)
+EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use, Categories.voice)
EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
SIGKILL = 9
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 60170a88f..13533a467 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -1,4 +1,4 @@
-from bot.utils.helpers import CogABCMeta, find_nth_occurrence, pad_base64
+from bot.utils.helpers import CogABCMeta, find_nth_occurrence, has_lines, pad_base64
from bot.utils.services import send_to_paste_service
-__all__ = ['CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service']
+__all__ = ['CogABCMeta', 'find_nth_occurrence', 'has_lines', 'pad_base64', 'send_to_paste_service']
diff --git a/bot/utils/channel.py b/bot/utils/channel.py
new file mode 100644
index 000000000..6bf70bfde
--- /dev/null
+++ b/bot/utils/channel.py
@@ -0,0 +1,49 @@
+import logging
+
+import discord
+
+from bot import constants
+from bot.constants import Categories
+
+log = logging.getLogger(__name__)
+
+
+def is_help_channel(channel: discord.TextChannel) -> bool:
+ """Return True if `channel` is in one of the help categories (excluding dormant)."""
+ log.trace(f"Checking if #{channel} is a help channel.")
+ categories = (Categories.help_available, Categories.help_in_use)
+
+ return any(is_in_category(channel, category) for category in categories)
+
+
+def is_mod_channel(channel: discord.TextChannel) -> bool:
+ """True if `channel` is considered a mod channel."""
+ if channel.id in constants.MODERATION_CHANNELS:
+ log.trace(f"Channel #{channel} is a configured mod channel")
+ return True
+
+ elif any(is_in_category(channel, category) for category in constants.MODERATION_CATEGORIES):
+ log.trace(f"Channel #{channel} is in a configured mod category")
+ return True
+
+ else:
+ log.trace(f"Channel #{channel} is not a mod channel")
+ return False
+
+
+def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
+ """Return True if `channel` is within a category with `category_id`."""
+ return getattr(channel, "category_id", None) == category_id
+
+
+async def try_get_channel(channel_id: int, client: discord.Client) -> discord.abc.GuildChannel:
+ """Attempt to get or fetch a channel and return it."""
+ log.trace(f"Getting the channel {channel_id}.")
+
+ channel = client.get_channel(channel_id)
+ if not channel:
+ log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
+ channel = await client.fetch_channel(channel_id)
+
+ log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
+ return channel
diff --git a/bot/utils/helpers.py b/bot/utils/helpers.py
index d9b60af07..3501a3933 100644
--- a/bot/utils/helpers.py
+++ b/bot/utils/helpers.py
@@ -18,6 +18,15 @@ def find_nth_occurrence(string: str, substring: str, n: int) -> Optional[int]:
return index
+def has_lines(string: str, count: int) -> bool:
+ """Return True if `string` has at least `count` lines."""
+ # Benchmarks show this is significantly faster than using str.count("\n") or a for loop & break.
+ split = string.split("\n", count - 1)
+
+ # Make sure the last part isn't empty, which would happen if there was a final newline.
+ return split[-1] and len(split) == count
+
+
def pad_base64(data: str) -> str:
"""Return base64 `data` with padding characters to ensure its length is a multiple of 4."""
return data + "=" * (-len(data) % 4)
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index d0b2342b3..b6c7cab50 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -56,15 +56,24 @@ async def wait_for_deletion(
async def send_attachments(
message: discord.Message,
destination: Union[discord.TextChannel, discord.Webhook],
- link_large: bool = True
+ link_large: bool = True,
+ use_cached: bool = False,
+ **kwargs
) -> List[str]:
"""
Re-upload the message's attachments to the destination and return a list of their new URLs.
Each attachment is sent as a separate message to more easily comply with the request/file size
limit. If link_large is True, attachments which are too large are instead grouped into a single
- embed which links to them.
+ embed which links to them. Extra kwargs will be passed to send() when sending the attachment.
"""
+ webhook_send_kwargs = {
+ 'username': message.author.display_name,
+ 'avatar_url': message.author.avatar_url,
+ }
+ webhook_send_kwargs.update(kwargs)
+ webhook_send_kwargs['username'] = sub_clyde(webhook_send_kwargs['username'])
+
large = []
urls = []
for attachment in message.attachments:
@@ -78,18 +87,14 @@ async def send_attachments(
# but some may get through hence the try-catch.
if attachment.size <= destination.guild.filesize_limit - 512:
with BytesIO() as file:
- await attachment.save(file, use_cached=True)
+ await attachment.save(file, use_cached=use_cached)
attachment_file = discord.File(file, filename=attachment.filename)
if isinstance(destination, discord.TextChannel):
- msg = await destination.send(file=attachment_file)
+ msg = await destination.send(file=attachment_file, **kwargs)
urls.append(msg.attachments[0].url)
else:
- await destination.send(
- file=attachment_file,
- username=sub_clyde(message.author.display_name),
- avatar_url=message.author.avatar_url
- )
+ await destination.send(file=attachment_file, **webhook_send_kwargs)
elif link_large:
large.append(attachment)
else:
@@ -106,13 +111,9 @@ async def send_attachments(
embed.set_footer(text="Attachments exceed upload size limit.")
if isinstance(destination, discord.TextChannel):
- await destination.send(embed=embed)
+ await destination.send(embed=embed, **kwargs)
else:
- await destination.send(
- embed=embed,
- username=sub_clyde(message.author.display_name),
- avatar_url=message.author.avatar_url
- )
+ await destination.send(embed=embed, **webhook_send_kwargs)
return urls
diff --git a/config-default.yml b/config-default.yml
index 4f7b1e217..071f6e1ec 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -119,6 +119,7 @@ style:
voice_state_green: "https://cdn.discordapp.com/emojis/656899770094452754.png"
voice_state_red: "https://cdn.discordapp.com/emojis/656899769905709076.png"
+
guild:
id: 267624335836053506
invite: "https://discord.gg/python"
@@ -127,7 +128,9 @@ guild:
help_available: 691405807388196926
help_in_use: 696958401460043776
help_dormant: 691405908919451718
- modmail: 714494672835444826
+ modmail: &MODMAIL 714494672835444826
+ logs: &LOGS 468520609152892958
+ voice: 356013253765234688
channels:
# Public announcement and news channels
@@ -145,8 +148,8 @@ guild:
dev_log: &DEV_LOG 622895325144940554
# Discussion
- meta: 429409067623251969
- python_discussion: 267624335836053506
+ meta: 429409067623251969
+ python_discussion: &PY_DISCUSSION 267624335836053506
# Python Help: Available
how_to_get_help: 704250143020417084
@@ -169,6 +172,7 @@ guild:
bot_commands: &BOT_CMD 267659945086812160
esoteric: 470884583684964352
verification: 352442727016693763
+ voice_gate: 764802555427029012
# Staff
admins: &ADMINS 365960823622991872
@@ -178,7 +182,7 @@ guild:
incidents: 714214212200562749
incidents_archive: 720668923636351037
mods: &MODS 305126844661760000
- mod_alerts: &MOD_ALERTS 473092532147060736
+ mod_alerts: 473092532147060736
mod_spam: &MOD_SPAM 620607373828030464
organisation: &ORGANISATION 551789653284356126
staff_lounge: &STAFF_LOUNGE 464905259261755392
@@ -191,6 +195,8 @@ guild:
# Voice
code_help_voice: 755154969761677312
+ code_help_voice_2: 766330079135268884
+ voice_chat: 412357430186344448
admins_voice: &ADMINS_VOICE 500734494840717332
staff_voice: &STAFF_VOICE 412375055910043655
@@ -198,10 +204,13 @@ guild:
big_brother_logs: &BB_LOGS 468507907357409333
talent_pool: &TALENT_POOL 534321732593647616
+ moderation_categories:
+ - *MODMAIL
+ - *LOGS
+
moderation_channels:
- *ADMINS
- *ADMIN_SPAM
- - *MOD_ALERTS
- *MODS
- *MOD_SPAM
@@ -225,9 +234,11 @@ guild:
muted: &MUTED_ROLE 277914926603829249
partners: 323426753857191936
python_community: &PY_COMMUNITY_ROLE 458226413825294336
+ sprinters: &SPRINTERS 758422482289426471
unverified: 739794855945044069
verified: 352427296948486144 # @Developers on PyDis
+ voice_verified: 764802720779337729
# Staff
admins: &ADMINS_ROLE 267628507062992896
@@ -261,6 +272,7 @@ guild:
reddit: 635408384794951680
talent_pool: 569145364800602132
+
filter:
# What do we filter?
filter_zalgo: false
@@ -298,6 +310,7 @@ filter:
- *OWNERS_ROLE
- *HELPERS_ROLE
- *PY_COMMUNITY_ROLE
+ - *SPRINTERS
keys:
@@ -326,6 +339,7 @@ urls:
bot_avatar: "https://raw.githubusercontent.com/discord-python/branding/master/logos/logo_circle/logo_circle.png"
github_bot_repo: "https://github.com/python-discord/bot"
+
anti_spam:
# Clean messages that violate a rule.
clean_offending: true
@@ -394,6 +408,23 @@ big_brother:
header_message_limit: 15
+code_block:
+ # The channels in which code blocks will be detected. They are not subject to a cooldown.
+ channel_whitelist:
+ - *BOT_CMD
+
+ # The channels which will be affected by a cooldown. These channels are also whitelisted.
+ cooldown_channels:
+ - *PY_DISCUSSION
+
+ # Sending instructions triggers a cooldown on a per-channel basis.
+ # More instruction messages will not be sent in the same channel until the cooldown has elapsed.
+ cooldown_seconds: 300
+
+ # The minimum amount of lines a message or code block must have for instructions to be sent.
+ minimum_lines: 4
+
+
free:
# Seconds to elapse for a channel
# to be considered inactive.
@@ -442,10 +473,12 @@ help_channels:
notify_roles:
- *HELPERS_ROLE
+
redirect_output:
delete_invocation: true
delete_delay: 15
+
duck_pond:
threshold: 4
channel_blacklist:
@@ -461,11 +494,13 @@ duck_pond:
- *MOD_ANNOUNCEMENTS
- *ADMIN_ANNOUNCEMENTS
+
python_news:
mail_lists:
- 'python-ideas'
- 'python-announce-list'
- 'pypi-announce'
+ - 'python-dev'
channel: *PYNEWS_CHANNEL
webhook: *PYNEWS_WEBHOOK
@@ -482,5 +517,11 @@ verification:
kick_confirmation_threshold: 0.01 # 1%
+voice_gate:
+ minimum_days_verified: 3 # How many days the user must have been verified for
+ minimum_messages: 50 # How many messages a user must have to be eligible for voice
+ bot_message_delete_delay: 10 # Seconds before deleting bot's response in Voice Gate
+
+
config:
required_keys: ['bot.token']
diff --git a/docker-compose.yml b/docker-compose.yml
index cff7d33d6..8be5aac0e 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -41,6 +41,7 @@ services:
- postgres
environment:
DATABASE_URL: postgres://pysite:pysite@postgres:5432/pysite
+ METRICITY_DB_URL: postgres://pysite:pysite@postgres:5432/metricity
SECRET_KEY: suitable-for-development-only
STATIC_ROOT: /var/www/static
diff --git a/tests/_autospec.py b/tests/_autospec.py
new file mode 100644
index 000000000..ee2fc1973
--- /dev/null
+++ b/tests/_autospec.py
@@ -0,0 +1,64 @@
+import contextlib
+import functools
+import unittest.mock
+from typing import Callable
+
+
[email protected](unittest.mock._patch.decoration_helper)
+def _decoration_helper(self, patched, args, keywargs):
+ """Skips adding patchings as args if their `dont_pass` attribute is True."""
+ # Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added.
+ extra_args = []
+ with contextlib.ExitStack() as exit_stack:
+ for patching in patched.patchings:
+ arg = exit_stack.enter_context(patching)
+ if not getattr(patching, "dont_pass", False):
+ # Only add the patching as an arg if dont_pass is False.
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is unittest.mock.DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield args, keywargs
+
+
[email protected](unittest.mock._patch.copy)
+def _copy(self):
+ """Copy the `dont_pass` attribute along with the standard copy operation."""
+ patcher_copy = _copy.original(self)
+ patcher_copy.dont_pass = getattr(self, "dont_pass", False)
+ return patcher_copy
+
+
+# Monkey-patch the patcher class :)
+_copy.original = unittest.mock._patch.copy
+unittest.mock._patch.copy = _copy
+unittest.mock._patch.decoration_helper = _decoration_helper
+
+
+def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable:
+ """
+ Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.
+
+ If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object.
+ """
+ # Caller's kwargs should take priority and overwrite the defaults.
+ kwargs = dict(spec_set=True, autospec=True)
+ kwargs.update(patch_kwargs)
+
+ # Import the target if it's a string.
+ # This is to support both object and string targets like patch.multiple.
+ if type(target) is str:
+ target = unittest.mock._importer(target)
+
+ def decorator(func):
+ for attribute in attributes:
+ patcher = unittest.mock.patch.object(target, attribute, **kwargs)
+ if not pass_mocks:
+ # A custom attribute to keep track of which patchings should be skipped.
+ patcher.dont_pass = True
+ func = patcher(func)
+ return func
+ return decorator
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index c0a1da35c..9f380a15d 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,7 +1,6 @@
import unittest
-from unittest import mock
-from bot.exts.backend.sync._syncers import UserSyncer, _Diff, _User
+from bot.exts.backend.sync._syncers import UserSyncer, _Diff
from tests import helpers
@@ -10,7 +9,7 @@ def fake_user(**kwargs):
kwargs.setdefault("id", 43)
kwargs.setdefault("name", "bob the test man")
kwargs.setdefault("discriminator", 1337)
- kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("roles", [666])
kwargs.setdefault("in_guild", True)
return kwargs
@@ -40,22 +39,42 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
return guild
+ @staticmethod
+ def get_mock_member(member: dict):
+ member = member.copy()
+ del member["in_guild"]
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+ return mock_member
+
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": []
+ }
guild = self.get_guild()
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_identical_users(self):
"""No differences should be found if the users in the guild and DB are identical."""
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.return_value = self.get_mock_member(fake_user())
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -63,59 +82,102 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(id=99, name="old"), fake_user()]
+ }
guild = self.get_guild(updated_user, fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(updated_user),
+ self.get_mock_member(fake_user())
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**updated_user)}, None)
+ expected_diff = ([], [{"id": 99, "name": "new"}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_users(self):
- """Only new users should be added to the 'created' set of the diff."""
+ """Only new users should be added to the 'created' list of the diff."""
new_user = fake_user(id=99, name="new")
- self.bot.api_client.get.return_value = [fake_user()]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user()]
+ }
guild = self.get_guild(fake_user(), new_user)
-
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(new_user)
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, set(), None)
+ expected_diff = ([new_user], [], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- leaving_user = fake_user(id=63, in_guild=False)
-
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), {_User(**leaving_user)}, None)
+ expected_diff = ([], [{"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_diff_for_new_updated_and_leaving_users(self):
"""When users are added, updated, and removed, all of them are returned properly."""
new_user = fake_user(id=99, name="new")
+
updated_user = fake_user(id=55, name="updated")
- leaving_user = fake_user(id=63, in_guild=False)
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=55), fake_user(id=63)]
+ }
guild = self.get_guild(fake_user(), new_user, updated_user)
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ self.get_mock_member(updated_user),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+ expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
async def test_empty_diff_for_db_users_not_in_guild(self):
- """When the DB knows a user the guild doesn't, no difference is found."""
- self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ """When the DB knows a user, but the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = {
+ "count": 3,
+ "next_page_no": None,
+ "previous_page_no": None,
+ "results": [fake_user(), fake_user(id=63, in_guild=False)]
+ }
guild = self.get_guild(fake_user())
+ guild.get_member.side_effect = [
+ self.get_mock_member(fake_user()),
+ None
+ ]
actual_diff = await self.syncer._get_diff(guild)
- expected_diff = (set(), set(), None)
+ expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -131,13 +193,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(user_tuples, set(), None)
+ diff = _Diff(users, [], None)
await self.syncer._sync(diff)
- calls = [mock.call("bot/users", json=user) for user in users]
- self.bot.api_client.post.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.post.call_count, len(users))
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
@@ -146,13 +205,10 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Only PUT requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
- user_tuples = {_User(**user) for user in users}
- diff = _Diff(set(), user_tuples, None)
+ diff = _Diff([], users, None)
await self.syncer._sync(diff)
- calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
- self.bot.api_client.put.assert_has_calls(calls, any_order=True)
- self.assertEqual(self.bot.api_client.put.call_count, len(users))
+ self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index 36a35c8e2..daede54c5 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -92,77 +92,6 @@ class InformationCogTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(admin_embed.title, "Admins info")
self.assertEqual(admin_embed.colour, discord.Colour.red())
- @unittest.mock.patch('bot.exts.info.information.time_since')
- async def test_server_info_command(self, time_since_patch):
- time_since_patch.return_value = '2 days ago'
-
- self.ctx.guild = helpers.MockGuild(
- features=('lemons', 'apples'),
- region="The Moon",
- roles=[self.moderator_role],
- channels=[
- discord.TextChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'}
- ),
- discord.CategoryChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'}
- ),
- discord.VoiceChannel(
- state={},
- guild=self.ctx.guild,
- data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'}
- )
- ],
- members=[
- *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
- *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
- *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
- *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
- ],
- member_count=1_234,
- icon_url='a-lemon.jpg',
- )
-
- self.assertIsNone(await self.cog.server_info(self.cog, self.ctx))
-
- time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days')
- _, kwargs = self.ctx.send.call_args
- embed = kwargs.pop('embed')
- self.assertEqual(embed.colour, discord.Colour.blurple())
- self.assertEqual(
- embed.description,
- textwrap.dedent(
- f"""
- **Server information**
- Created: {time_since_patch.return_value}
- Voice region: {self.ctx.guild.region}
- Features: {', '.join(self.ctx.guild.features)}
-
- **Channel counts**
- Category channels: 1
- Text channels: 1
- Voice channels: 1
- Staff channels: 0
-
- **Member counts**
- Members: {self.ctx.guild.member_count:,}
- Staff members: 0
- Roles: {len(self.ctx.guild.roles)}
-
- **Member statuses**
- {constants.Emojis.status_online} 2
- {constants.Emojis.status_idle} 1
- {constants.Emojis.status_dnd} 4
- {constants.Emojis.status_offline} 3
- """
- )
- )
- self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg')
-
class UserInfractionHelperMethodTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the helper methods of the `!user` command."""
@@ -465,7 +394,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
"basic infractions info",
- embed.fields[3].value
+ embed.fields[2].value
)
@unittest.mock.patch(
diff --git a/tests/bot/exts/moderation/infraction/test_infractions.py b/tests/bot/exts/moderation/infraction/test_infractions.py
index be1b649e1..bf557a484 100644
--- a/tests/bot/exts/moderation/infraction/test_infractions.py
+++ b/tests/bot/exts/moderation/infraction/test_infractions.py
@@ -1,7 +1,8 @@
import textwrap
import unittest
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from bot.constants import Event
from bot.exts.moderation.infraction.infractions import Infractions
from tests.helpers import MockBot, MockContext, MockGuild, MockMember, MockRole
@@ -53,3 +54,148 @@ class TruncationTests(unittest.IsolatedAsyncioTestCase):
self.cog.apply_infraction.assert_awaited_once_with(
self.ctx, {"foo": "bar"}, self.target, self.target.kick.return_value
)
+
+
+@patch("bot.exts.moderation.infraction.infractions.constants.Roles.voice_verified", new=123456)
+class VoiceBanTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for voice ban related functions and commands."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.mod = MockMember(top_role=10)
+ self.user = MockMember(top_role=1, roles=[MockRole(id=123456)])
+ self.guild = MockGuild()
+ self.ctx = MockContext(bot=self.bot, author=self.mod)
+ self.cog = Infractions(self.bot)
+
+ async def test_permanent_voice_ban(self):
+ """Should call voice ban applying function without expiry."""
+ self.cog.apply_voice_ban = AsyncMock()
+ self.assertIsNone(await self.cog.voiceban(self.cog, self.ctx, self.user, reason="foobar"))
+ self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar")
+
+ async def test_temporary_voice_ban(self):
+ """Should call voice ban applying function with expiry."""
+ self.cog.apply_voice_ban = AsyncMock()
+ self.assertIsNone(await self.cog.tempvoiceban(self.cog, self.ctx, self.user, "baz", reason="foobar"))
+ self.cog.apply_voice_ban.assert_awaited_once_with(self.ctx, self.user, "foobar", expires_at="baz")
+
+ async def test_voice_unban(self):
+ """Should call infraction pardoning function."""
+ self.cog.pardon_infraction = AsyncMock()
+ self.assertIsNone(await self.cog.unvoiceban(self.cog, self.ctx, self.user))
+ self.cog.pardon_infraction.assert_awaited_once_with(self.ctx, "voice_ban", self.user)
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_user_have_active_infraction(self, get_active_infraction, post_infraction_mock):
+ """Should return early when user already have Voice Ban infraction."""
+ get_active_infraction.return_value = {"foo": "bar"}
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ get_active_infraction.assert_awaited_once_with(self.ctx, self.user, "voice_ban")
+ post_infraction_mock.assert_not_awaited()
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_infraction_post_failed(self, get_active_infraction, post_infraction_mock):
+ """Should return early when posting infraction fails."""
+ self.cog.mod_log.ignore = MagicMock()
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = None
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ post_infraction_mock.assert_awaited_once()
+ self.cog.mod_log.ignore.assert_not_called()
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_infraction_post_add_kwargs(self, get_active_infraction, post_infraction_mock):
+ """Should pass all kwargs passed to apply_voice_ban to post_infraction."""
+ get_active_infraction.return_value = None
+ # We don't want that this continue yet
+ post_infraction_mock.return_value = None
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar", my_kwarg=23))
+ post_infraction_mock.assert_awaited_once_with(
+ self.ctx, self.user, "voice_ban", "foobar", active=True, my_kwarg=23
+ )
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_mod_log_ignore(self, get_active_infraction, post_infraction_mock):
+ """Should ignore Voice Verified role removing."""
+ self.cog.mod_log.ignore = MagicMock()
+ self.cog.apply_infraction = AsyncMock()
+ self.user.remove_roles = MagicMock(return_value="my_return_value")
+
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ self.cog.mod_log.ignore.assert_called_once_with(Event.member_update, self.user.id)
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_apply_infraction(self, get_active_infraction, post_infraction_mock):
+ """Should ignore Voice Verified role removing."""
+ self.cog.mod_log.ignore = MagicMock()
+ self.cog.apply_infraction = AsyncMock()
+ self.user.remove_roles = MagicMock(return_value="my_return_value")
+
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar"))
+ self.user.remove_roles.assert_called_once_with(self.cog._voice_verified_role, reason="foobar")
+ self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value")
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.post_infraction")
+ @patch("bot.exts.moderation.infraction.infractions._utils.get_active_infraction")
+ async def test_voice_ban_truncate_reason(self, get_active_infraction, post_infraction_mock):
+ """Should truncate reason for voice ban."""
+ self.cog.mod_log.ignore = MagicMock()
+ self.cog.apply_infraction = AsyncMock()
+ self.user.remove_roles = MagicMock(return_value="my_return_value")
+
+ get_active_infraction.return_value = None
+ post_infraction_mock.return_value = {"foo": "bar"}
+
+ self.assertIsNone(await self.cog.apply_voice_ban(self.ctx, self.user, "foobar" * 3000))
+ self.user.remove_roles.assert_called_once_with(
+ self.cog._voice_verified_role, reason=textwrap.shorten("foobar" * 3000, 512, placeholder="...")
+ )
+ self.cog.apply_infraction.assert_awaited_once_with(self.ctx, {"foo": "bar"}, self.user, "my_return_value")
+
+ async def test_voice_unban_user_not_found(self):
+ """Should include info to return dict when user was not found from guild."""
+ self.guild.get_member.return_value = None
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.assertEqual(result, {"Info": "User was not found in the guild."})
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
+ @patch("bot.exts.moderation.infraction.infractions.format_user")
+ async def test_voice_unban_user_found(self, format_user_mock, notify_pardon_mock):
+ """Should add role back with ignoring, notify user and return log dictionary.."""
+ self.guild.get_member.return_value = self.user
+ notify_pardon_mock.return_value = True
+ format_user_mock.return_value = "my-user"
+
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.assertEqual(result, {
+ "Member": "my-user",
+ "DM": "Sent"
+ })
+ notify_pardon_mock.assert_awaited_once()
+
+ @patch("bot.exts.moderation.infraction.infractions._utils.notify_pardon")
+ @patch("bot.exts.moderation.infraction.infractions.format_user")
+ async def test_voice_unban_dm_fail(self, format_user_mock, notify_pardon_mock):
+ """Should add role back with ignoring, notify user and return log dictionary.."""
+ self.guild.get_member.return_value = self.user
+ notify_pardon_mock.return_value = False
+ format_user_mock.return_value = "my-user"
+
+ result = await self.cog.pardon_voice_ban(self.user.id, self.guild, "foobar")
+ self.assertEqual(result, {
+ "Member": "my-user",
+ "DM": "**Failed**"
+ })
+ notify_pardon_mock.assert_awaited_once()
diff --git a/tests/bot/exts/moderation/test_silence.py b/tests/bot/exts/moderation/test_silence.py
index 3c2d52ae0..104293d8e 100644
--- a/tests/bot/exts/moderation/test_silence.py
+++ b/tests/bot/exts/moderation/test_silence.py
@@ -1,23 +1,49 @@
+import asyncio
import unittest
+from datetime import datetime, timezone
from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
+from async_rediscache import RedisSession
from discord import PermissionOverwrite
-from bot.constants import Channels, Emojis, Guild, Roles
-from bot.exts.moderation.silence import Silence, SilenceNotifier
-from tests.helpers import MockBot, MockContext, MockTextChannel
+from bot.constants import Channels, Guild, Roles
+from bot.exts.moderation import silence
+from tests.helpers import MockBot, MockContext, MockTextChannel, autospec
+
+redis_session = None
+redis_loop = asyncio.get_event_loop()
+
+
+def setUpModule(): # noqa: N802
+ """Create and connect to the fakeredis session."""
+ global redis_session
+ redis_session = RedisSession(use_fakeredis=True)
+ redis_loop.run_until_complete(redis_session.connect())
+
+
+def tearDownModule(): # noqa: N802
+ """Close the fakeredis session."""
+ if redis_session:
+ redis_loop.run_until_complete(redis_session.close())
+
+
+# Have to subclass it because builtins can't be patched.
+class PatchedDatetime(datetime):
+ """A datetime object with a mocked now() function."""
+
+ now = mock.create_autospec(datetime, "now")
class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.alert_channel = MockTextChannel()
- self.notifier = SilenceNotifier(self.alert_channel)
+ self.notifier = silence.SilenceNotifier(self.alert_channel)
self.notifier.stop = self.notifier_stop_mock = Mock()
self.notifier.start = self.notifier_start_mock = Mock()
def test_add_channel_adds_channel(self):
- """Channel in FirstHash with current loop is added to internal set."""
+ """Channel is added to `_silenced_channels` with the current loop."""
channel = Mock()
with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
self.notifier.add_channel(channel)
@@ -35,7 +61,7 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
self.notifier_start_mock.assert_not_called()
def test_remove_channel_removes_channel(self):
- """Channel in FirstHash is removed from `_silenced_channels`."""
+ """Channel is removed from `_silenced_channels`."""
channel = Mock()
with mock.patch.object(self.notifier, "_silenced_channels") as silenced_channels:
self.notifier.remove_channel(channel)
@@ -59,7 +85,9 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
with self.subTest(current_loop=current_loop):
with mock.patch.object(self.notifier, "_current_loop", new=current_loop):
await self.notifier._notifier()
- self.alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> currently silenced channels: ")
+ self.alert_channel.send.assert_called_once_with(
+ f"<@&{Roles.moderators}> currently silenced channels: "
+ )
self.alert_channel.send.reset_mock()
async def test_notifier_skips_alert(self):
@@ -72,192 +100,403 @@ class SilenceNotifierTests(unittest.IsolatedAsyncioTestCase):
self.alert_channel.send.assert_not_called()
-class SilenceTests(unittest.IsolatedAsyncioTestCase):
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class SilenceCogTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the general functionality of the Silence cog."""
+
+ @autospec(silence, "Scheduler", pass_mocks=False)
def setUp(self) -> None:
self.bot = MockBot()
- self.cog = Silence(self.bot)
- self.ctx = MockContext()
- self.cog._verified_role = None
- # Set event so command callbacks can continue.
- self.cog._get_instance_vars_event.set()
+ self.cog = silence.Silence(self.bot)
- async def test_instance_vars_got_guild(self):
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_guild(self):
"""Bot got guild after it became available."""
- await self.cog._get_instance_vars()
- self.bot.wait_until_guild_available.assert_called_once()
+ await self.cog._async_init()
+ self.bot.wait_until_guild_available.assert_awaited_once()
self.bot.get_guild.assert_called_once_with(Guild.id)
- async def test_instance_vars_got_role(self):
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_role(self):
"""Got `Roles.verified` role from guild."""
- await self.cog._get_instance_vars()
guild = self.bot.get_guild()
- guild.get_role.assert_called_once_with(Roles.verified)
+ guild.get_role.side_effect = lambda id_: Mock(id=id_)
- async def test_instance_vars_got_channels(self):
+ await self.cog._async_init()
+ self.assertEqual(self.cog._verified_role.id, Roles.verified)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_got_channels(self):
"""Got channels from bot."""
- await self.cog._get_instance_vars()
- self.bot.get_channel.called_once_with(Channels.mod_alerts)
- self.bot.get_channel.called_once_with(Channels.mod_log)
+ self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
+
+ await self.cog._async_init()
+ self.assertEqual(self.cog._mod_alerts_channel.id, Channels.mod_alerts)
- @mock.patch("bot.exts.moderation.silence.SilenceNotifier")
- async def test_instance_vars_got_notifier(self, notifier):
+ @autospec(silence, "SilenceNotifier")
+ async def test_async_init_got_notifier(self, notifier):
"""Notifier was started with channel."""
- mod_log = MockTextChannel()
- self.bot.get_channel.side_effect = (None, mod_log)
- await self.cog._get_instance_vars()
- notifier.assert_called_once_with(mod_log)
- self.bot.get_channel.side_effect = None
-
- async def test_silence_sent_correct_discord_message(self):
- """Check if proper message was sent when called with duration in channel with previous state."""
+ self.bot.get_channel.side_effect = lambda id_: MockTextChannel(id=id_)
+
+ await self.cog._async_init()
+ notifier.assert_called_once_with(MockTextChannel(id=Channels.mod_log))
+ self.assertEqual(self.cog.notifier, notifier.return_value)
+
+ @autospec(silence, "SilenceNotifier", pass_mocks=False)
+ async def test_async_init_rescheduled(self):
+ """`_reschedule_` coroutine was awaited."""
+ self.cog._reschedule = mock.create_autospec(self.cog._reschedule)
+ await self.cog._async_init()
+ self.cog._reschedule.assert_awaited_once_with()
+
+ def test_cog_unload_cancelled_tasks(self):
+ """The init task was cancelled."""
+ self.cog._init_task = asyncio.Future()
+ self.cog.cog_unload()
+
+ # It's too annoying to test cancel_all since it's a done callback and wrapped in a lambda.
+ self.assertTrue(self.cog._init_task.cancelled())
+
+ @autospec("discord.ext.commands", "has_any_role")
+ @mock.patch.object(silence, "MODERATION_ROLES", new=(1, 2, 3))
+ async def test_cog_check(self, role_check):
+ """Role check was called with `MODERATION_ROLES`"""
+ ctx = MockContext()
+ role_check.return_value.predicate = mock.AsyncMock()
+
+ await self.cog.cog_check(ctx)
+ role_check.assert_called_once_with(*(1, 2, 3))
+ role_check.return_value.predicate.assert_awaited_once_with(ctx)
+
+
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class RescheduleTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the rescheduling of cached unsilences."""
+
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self):
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._unsilence_wrapper = mock.create_autospec(self.cog._unsilence_wrapper)
+
+ with mock.patch.object(self.cog, "_reschedule", autospec=True):
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ async def test_skipped_missing_channel(self):
+ """Did nothing because the channel couldn't be retrieved."""
+ self.cog.unsilence_timestamps.items.return_value = [(123, -1), (123, 1), (123, 10000000000)]
+ self.bot.get_channel.return_value = None
+
+ await self.cog._reschedule()
+
+ self.cog.notifier.add_channel.assert_not_called()
+ self.cog._unsilence_wrapper.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ async def test_added_permanent_to_notifier(self):
+ """Permanently silenced channels were added to the notifier."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, -1), (456, -1)]
+
+ await self.cog._reschedule()
+
+ self.cog.notifier.add_channel.assert_any_call(channels[0])
+ self.cog.notifier.add_channel.assert_any_call(channels[1])
+
+ self.cog._unsilence_wrapper.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ async def test_unsilenced_expired(self):
+ """Unsilenced expired silences."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, 100), (456, 200)]
+
+ await self.cog._reschedule()
+
+ self.cog._unsilence_wrapper.assert_any_call(channels[0])
+ self.cog._unsilence_wrapper.assert_any_call(channels[1])
+
+ self.cog.notifier.add_channel.assert_not_called()
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+ @mock.patch.object(silence, "datetime", new=PatchedDatetime)
+ async def test_rescheduled_active(self):
+ """Rescheduled active silences."""
+ channels = [MockTextChannel(id=123), MockTextChannel(id=456)]
+ self.bot.get_channel.side_effect = channels
+ self.cog.unsilence_timestamps.items.return_value = [(123, 2000), (456, 3000)]
+ silence.datetime.now.return_value = datetime.fromtimestamp(1000, tz=timezone.utc)
+
+ self.cog._unsilence_wrapper = mock.MagicMock()
+ unsilence_return = self.cog._unsilence_wrapper.return_value
+
+ await self.cog._reschedule()
+
+ # Yuck.
+ calls = [mock.call(1000, 123, unsilence_return), mock.call(2000, 456, unsilence_return)]
+ self.cog.scheduler.schedule_later.assert_has_calls(calls)
+
+ unsilence_calls = [mock.call(channel) for channel in channels]
+ self.cog._unsilence_wrapper.assert_has_calls(unsilence_calls)
+
+ self.cog.notifier.add_channel.assert_not_called()
+
+
+@autospec(silence.Silence, "previous_overwrites", "unsilence_timestamps", pass_mocks=False)
+class SilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the silence command and its related helper methods."""
+
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot()
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ # Avoid unawaited coroutine warnings.
+ self.cog.scheduler.schedule_later.side_effect = lambda delay, task_id, coro: coro.close()
+
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ self.channel = MockTextChannel()
+ self.overwrite = PermissionOverwrite(stream=True, send_messages=True, add_reactions=False)
+ self.channel.overwrites_for.return_value = self.overwrite
+
+ async def test_sent_correct_message(self):
+ """Appropriate failure/success message was sent by the command."""
test_cases = (
- (0.0001, f"{Emojis.check_mark} silenced current channel for 0.0001 minute(s).", True,),
- (None, f"{Emojis.check_mark} silenced current channel indefinitely.", True,),
- (5, f"{Emojis.cross_mark} current channel is already silenced.", False,),
+ (0.0001, silence.MSG_SILENCE_SUCCESS.format(duration=0.0001), True,),
+ (None, silence.MSG_SILENCE_PERMANENT, True,),
+ (5, silence.MSG_SILENCE_FAIL, False,),
)
- for duration, result_message, _silence_patch_return in test_cases:
- with self.subTest(
- silence_duration=duration,
- result_message=result_message,
- starting_unsilenced_state=_silence_patch_return
- ):
- with mock.patch.object(self.cog, "_silence", return_value=_silence_patch_return):
- await self.cog.silence(self.cog, self.ctx, duration)
- self.ctx.send.assert_called_once_with(result_message)
- self.ctx.reset_mock()
-
- async def test_unsilence_sent_correct_discord_message(self):
- """Check if proper message was sent when unsilencing channel."""
- test_cases = (
- (True, f"{Emojis.check_mark} unsilenced current channel."),
- (False, f"{Emojis.cross_mark} current channel was not silenced.")
+ for duration, message, was_silenced in test_cases:
+ ctx = MockContext()
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=was_silenced):
+ with self.subTest(was_silenced=was_silenced, message=message, duration=duration):
+ await self.cog.silence.callback(self.cog, ctx, duration)
+ ctx.send.assert_called_once_with(message)
+
+ async def test_skipped_already_silenced(self):
+ """Permissions were not set and `False` was returned for an already silenced channel."""
+ subtests = (
+ (False, PermissionOverwrite(send_messages=False, add_reactions=False)),
+ (True, PermissionOverwrite(send_messages=True, add_reactions=True)),
+ (True, PermissionOverwrite(send_messages=False, add_reactions=False)),
)
- for _unsilence_patch_return, result_message in test_cases:
- with self.subTest(
- starting_silenced_state=_unsilence_patch_return,
- result_message=result_message
- ):
- with mock.patch.object(self.cog, "_unsilence", return_value=_unsilence_patch_return):
- await self.cog.unsilence(self.cog, self.ctx)
- self.ctx.send.assert_called_once_with(result_message)
- self.ctx.reset_mock()
-
- async def test_silence_private_for_false(self):
- """Permissions are not set and `False` is returned in an already silenced channel."""
- perm_overwrite = Mock(send_messages=False)
- channel = Mock(overwrites_for=Mock(return_value=perm_overwrite))
-
- self.assertFalse(await self.cog._silence(channel, True, None))
- channel.set_permissions.assert_not_called()
- async def test_silence_private_silenced_channel(self):
- """Channel had `send_message` permissions revoked."""
- channel = MockTextChannel()
- self.assertTrue(await self.cog._silence(channel, False, None))
- channel.set_permissions.assert_called_once()
- self.assertFalse(channel.set_permissions.call_args.kwargs['send_messages'])
+ for contains, overwrite in subtests:
+ with self.subTest(contains=contains, overwrite=overwrite):
+ self.cog.scheduler.__contains__.return_value = contains
+ channel = MockTextChannel()
+ channel.overwrites_for.return_value = overwrite
+
+ self.assertFalse(await self.cog._set_silence_overwrites(channel))
+ channel.set_permissions.assert_not_called()
+
+ async def test_silenced_channel(self):
+ """Channel had `send_message` and `add_reactions` permissions revoked for verified role."""
+ self.assertTrue(await self.cog._set_silence_overwrites(self.channel))
+ self.assertFalse(self.overwrite.send_messages)
+ self.assertFalse(self.overwrite.add_reactions)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite
+ )
- async def test_silence_private_preserves_permissions(self):
- """Previous permissions were preserved when channel was silenced."""
- channel = MockTextChannel()
- # Set up mock channel permission state.
- mock_permissions = PermissionOverwrite()
- mock_permissions_dict = dict(mock_permissions)
- channel.overwrites_for.return_value = mock_permissions
- await self.cog._silence(channel, False, None)
- new_permissions = channel.set_permissions.call_args.kwargs
- # Remove 'send_messages' key because it got changed in the method.
- del new_permissions['send_messages']
- del mock_permissions_dict['send_messages']
- self.assertDictEqual(mock_permissions_dict, new_permissions)
-
- async def test_silence_private_notifier(self):
- """Channel should be added to notifier with `persistent` set to `True`, and the other way around."""
- channel = MockTextChannel()
- with mock.patch.object(self.cog, "notifier", create=True):
- with self.subTest(persistent=True):
- await self.cog._silence(channel, True, None)
- self.cog.notifier.add_channel.assert_called_once()
-
- with mock.patch.object(self.cog, "notifier", create=True):
- with self.subTest(persistent=False):
- await self.cog._silence(channel, False, None)
- self.cog.notifier.add_channel.assert_not_called()
-
- async def test_silence_private_added_muted_channel(self):
- """Channel was added to `muted_channels` on silence."""
+ async def test_preserved_other_overwrites(self):
+ """Channel's other unrelated overwrites were not changed."""
+ prev_overwrite_dict = dict(self.overwrite)
+ await self.cog._set_silence_overwrites(self.channel)
+ new_overwrite_dict = dict(self.overwrite)
+
+ # Remove 'send_messages' & 'add_reactions' keys because they were changed by the method.
+ del prev_overwrite_dict['send_messages']
+ del prev_overwrite_dict['add_reactions']
+ del new_overwrite_dict['send_messages']
+ del new_overwrite_dict['add_reactions']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
+
+ async def test_temp_not_added_to_notifier(self):
+ """Channel was not added to notifier if a duration was set for the silence."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ await self.cog.silence.callback(self.cog, MockContext(), 15)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_indefinite_added_to_notifier(self):
+ """Channel was added to notifier if a duration was not set for the silence."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=True):
+ await self.cog.silence.callback(self.cog, MockContext(), None)
+ self.cog.notifier.add_channel.assert_called_once()
+
+ async def test_silenced_not_added_to_notifier(self):
+ """Channel was not added to the notifier if it was already silenced."""
+ with mock.patch.object(self.cog, "_set_silence_overwrites", return_value=False):
+ await self.cog.silence.callback(self.cog, MockContext(), 15)
+ self.cog.notifier.add_channel.assert_not_called()
+
+ async def test_cached_previous_overwrites(self):
+ """Channel's previous overwrites were cached."""
+ overwrite_json = '{"send_messages": true, "add_reactions": false}'
+ await self.cog._set_silence_overwrites(self.channel)
+ self.cog.previous_overwrites.set.assert_called_once_with(self.channel.id, overwrite_json)
+
+ @autospec(silence, "datetime")
+ async def test_cached_unsilence_time(self, datetime_mock):
+ """The UTC POSIX timestamp for the unsilence was cached."""
+ now_timestamp = 100
+ duration = 15
+ timestamp = now_timestamp + duration * 60
+ datetime_mock.now.return_value = datetime.fromtimestamp(now_timestamp, tz=timezone.utc)
+
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, duration)
+
+ self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, timestamp)
+ datetime_mock.now.assert_called_once_with(tz=timezone.utc) # Ensure it's using an aware dt.
+
+ async def test_cached_indefinite_time(self):
+ """A value of -1 was cached for a permanent silence."""
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, None)
+ self.cog.unsilence_timestamps.set.assert_awaited_once_with(ctx.channel.id, -1)
+
+ async def test_scheduled_task(self):
+ """An unsilence task was scheduled."""
+ ctx = MockContext(channel=self.channel, invoke=mock.MagicMock())
+
+ await self.cog.silence.callback(self.cog, ctx, 5)
+
+ args = (300, ctx.channel.id, ctx.invoke.return_value)
+ self.cog.scheduler.schedule_later.assert_called_once_with(*args)
+ ctx.invoke.assert_called_once_with(self.cog.unsilence)
+
+ async def test_permanent_not_scheduled(self):
+ """A task was not scheduled for a permanent silence."""
+ ctx = MockContext(channel=self.channel)
+ await self.cog.silence.callback(self.cog, ctx, None)
+ self.cog.scheduler.schedule_later.assert_not_called()
+
+
+@autospec(silence.Silence, "unsilence_timestamps", pass_mocks=False)
+class UnsilenceTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for the unsilence command and its related helper methods."""
+
+ @autospec(silence.Silence, "_reschedule", pass_mocks=False)
+ @autospec(silence, "Scheduler", "SilenceNotifier", pass_mocks=False)
+ def setUp(self) -> None:
+ self.bot = MockBot(get_channel=lambda _: MockTextChannel())
+ self.cog = silence.Silence(self.bot)
+ self.cog._init_task = asyncio.Future()
+ self.cog._init_task.set_result(None)
+
+ overwrites_cache = mock.create_autospec(self.cog.previous_overwrites, spec_set=True)
+ self.cog.previous_overwrites = overwrites_cache
+
+ asyncio.run(self.cog._async_init()) # Populate instance attributes.
+
+ self.cog.scheduler.__contains__.return_value = True
+ overwrites_cache.get.return_value = '{"send_messages": true, "add_reactions": false}'
+ self.channel = MockTextChannel()
+ self.overwrite = PermissionOverwrite(stream=True, send_messages=False, add_reactions=False)
+ self.channel.overwrites_for.return_value = self.overwrite
+
+ async def test_sent_correct_message(self):
+ """Appropriate failure/success message was sent by the command."""
+ unsilenced_overwrite = PermissionOverwrite(send_messages=True, add_reactions=True)
+ test_cases = (
+ (True, silence.MSG_UNSILENCE_SUCCESS, unsilenced_overwrite),
+ (False, silence.MSG_UNSILENCE_FAIL, unsilenced_overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, self.overwrite),
+ (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(send_messages=False)),
+ (False, silence.MSG_UNSILENCE_MANUAL, PermissionOverwrite(add_reactions=False)),
+ )
+ for was_unsilenced, message, overwrite in test_cases:
+ ctx = MockContext()
+ with self.subTest(was_unsilenced=was_unsilenced, message=message, overwrite=overwrite):
+ with mock.patch.object(self.cog, "_unsilence", return_value=was_unsilenced):
+ ctx.channel.overwrites_for.return_value = overwrite
+ await self.cog.unsilence.callback(self.cog, ctx)
+ ctx.channel.send.assert_called_once_with(message)
+
+ async def test_skipped_already_unsilenced(self):
+ """Permissions were not set and `False` was returned for an already unsilenced channel."""
+ self.cog.scheduler.__contains__.return_value = False
+ self.cog.previous_overwrites.get.return_value = None
channel = MockTextChannel()
- with mock.patch.object(self.cog, "muted_channels") as muted_channels:
- await self.cog._silence(channel, False, None)
- muted_channels.add.assert_called_once_with(channel)
- async def test_unsilence_private_for_false(self):
- """Permissions are not set and `False` is returned in an unsilenced channel."""
- channel = Mock()
self.assertFalse(await self.cog._unsilence(channel))
channel.set_permissions.assert_not_called()
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_unsilenced_channel(self, _):
- """Channel had `send_message` permissions restored"""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- self.assertTrue(await self.cog._unsilence(channel))
- channel.set_permissions.assert_called_once()
- self.assertIsNone(channel.set_permissions.call_args.kwargs['send_messages'])
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_removed_notifier(self, notifier):
- """Channel was removed from `notifier` on unsilence."""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- await self.cog._unsilence(channel)
- notifier.remove_channel.assert_called_once_with(channel)
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_removed_muted_channel(self, _):
- """Channel was removed from `muted_channels` on unsilence."""
- perm_overwrite = MagicMock(send_messages=False)
- channel = MockTextChannel(overwrites_for=Mock(return_value=perm_overwrite))
- with mock.patch.object(self.cog, "muted_channels") as muted_channels:
- await self.cog._unsilence(channel)
- muted_channels.discard.assert_called_once_with(channel)
-
- @mock.patch.object(Silence, "notifier", create=True)
- async def test_unsilence_private_preserves_permissions(self, _):
- """Previous permissions were preserved when channel was unsilenced."""
- channel = MockTextChannel()
- # Set up mock channel permission state.
- mock_permissions = PermissionOverwrite(send_messages=False)
- mock_permissions_dict = dict(mock_permissions)
- channel.overwrites_for.return_value = mock_permissions
- await self.cog._unsilence(channel)
- new_permissions = channel.set_permissions.call_args.kwargs
- # Remove 'send_messages' key because it got changed in the method.
- del new_permissions['send_messages']
- del mock_permissions_dict['send_messages']
- self.assertDictEqual(mock_permissions_dict, new_permissions)
-
- @mock.patch("bot.exts.moderation.silence.asyncio")
- @mock.patch.object(Silence, "_mod_alerts_channel", create=True)
- def test_cog_unload_starts_task(self, alert_channel, asyncio_mock):
- """Task for sending an alert was created with present `muted_channels`."""
- with mock.patch.object(self.cog, "muted_channels"):
- self.cog.cog_unload()
- alert_channel.send.assert_called_once_with(f"<@&{Roles.moderators}> channels left silenced on cog unload: ")
- asyncio_mock.create_task.assert_called_once_with(alert_channel.send())
-
- @mock.patch("bot.exts.moderation.silence.asyncio")
- def test_cog_unload_skips_task_start(self, asyncio_mock):
- """No task created with no channels."""
- self.cog.cog_unload()
- asyncio_mock.create_task.assert_not_called()
+ async def test_restored_overwrites(self):
+ """Channel's `send_message` and `add_reactions` overwrites were restored."""
+ await self.cog._unsilence(self.channel)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite,
+ )
- @mock.patch("discord.ext.commands.has_any_role")
- @mock.patch("bot.exts.moderation.silence.MODERATION_ROLES", new=(1, 2, 3))
- async def test_cog_check(self, role_check):
- """Role check is called with `MODERATION_ROLES`"""
- role_check.return_value.predicate = mock.AsyncMock()
- await self.cog.cog_check(self.ctx)
- role_check.assert_called_once_with(*(1, 2, 3))
- role_check.return_value.predicate.assert_awaited_once_with(self.ctx)
+ # Recall that these values are determined by the fixture.
+ self.assertTrue(self.overwrite.send_messages)
+ self.assertFalse(self.overwrite.add_reactions)
+
+ async def test_cache_miss_used_default_overwrites(self):
+ """Both overwrites were set to None due previous values not being found in the cache."""
+ self.cog.previous_overwrites.get.return_value = None
+
+ await self.cog._unsilence(self.channel)
+ self.channel.set_permissions.assert_awaited_once_with(
+ self.cog._verified_role,
+ overwrite=self.overwrite,
+ )
+
+ self.assertIsNone(self.overwrite.send_messages)
+ self.assertIsNone(self.overwrite.add_reactions)
+
+ async def test_cache_miss_sent_mod_alert(self):
+ """A message was sent to the mod alerts channel."""
+ self.cog.previous_overwrites.get.return_value = None
+
+ await self.cog._unsilence(self.channel)
+ self.cog._mod_alerts_channel.send.assert_awaited_once()
+
+ async def test_removed_notifier(self):
+ """Channel was removed from `notifier`."""
+ await self.cog._unsilence(self.channel)
+ self.cog.notifier.remove_channel.assert_called_once_with(self.channel)
+
+ async def test_deleted_cached_overwrite(self):
+ """Channel was deleted from the overwrites cache."""
+ await self.cog._unsilence(self.channel)
+ self.cog.previous_overwrites.delete.assert_awaited_once_with(self.channel.id)
+
+ async def test_deleted_cached_time(self):
+ """Channel was deleted from the timestamp cache."""
+ await self.cog._unsilence(self.channel)
+ self.cog.unsilence_timestamps.delete.assert_awaited_once_with(self.channel.id)
+
+ async def test_cancelled_task(self):
+ """The scheduled unsilence task should be cancelled."""
+ await self.cog._unsilence(self.channel)
+ self.cog.scheduler.cancel.assert_called_once_with(self.channel.id)
+
+ async def test_preserved_other_overwrites(self):
+ """Channel's other unrelated overwrites were not changed, including cache misses."""
+ for overwrite_json in ('{"send_messages": true, "add_reactions": null}', None):
+ with self.subTest(overwrite_json=overwrite_json):
+ self.cog.previous_overwrites.get.return_value = overwrite_json
+
+ prev_overwrite_dict = dict(self.overwrite)
+ await self.cog._unsilence(self.channel)
+ new_overwrite_dict = dict(self.overwrite)
+
+ # Remove these keys because they were modified by the unsilence.
+ del prev_overwrite_dict['send_messages']
+ del prev_overwrite_dict['add_reactions']
+ del new_overwrite_dict['send_messages']
+ del new_overwrite_dict['add_reactions']
+
+ self.assertDictEqual(prev_overwrite_dict, new_overwrite_dict)
diff --git a/tests/helpers.py b/tests/helpers.py
index e47fdf28f..870f66197 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -5,7 +5,7 @@ import itertools
import logging
import unittest.mock
from asyncio import AbstractEventLoop
-from typing import Callable, Iterable, Optional
+from typing import Iterable, Optional
import discord
from aiohttp import ClientSession
@@ -14,6 +14,7 @@ from discord.ext.commands import Context
from bot.api import APIClient
from bot.async_stats import AsyncStatsClient
from bot.bot import Bot
+from tests._autospec import autospec # noqa: F401 other modules import it via this module
for logger in logging.Logger.manager.loggerDict.values():
@@ -26,24 +27,6 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def autospec(target, *attributes: str, **kwargs) -> Callable:
- """Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True."""
- # Caller's kwargs should take priority and overwrite the defaults.
- kwargs = {'spec_set': True, 'autospec': True, **kwargs}
-
- # Import the target if it's a string.
- # This is to support both object and string targets like patch.multiple.
- if type(target) is str:
- target = unittest.mock._importer(target)
-
- def decorator(func):
- for attribute in attributes:
- patcher = unittest.mock.patch.object(target, attribute, **kwargs)
- func = patcher(func)
- return func
- return decorator
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.